Merge main into fix/chore-fix

pull/12372/head
Yeuoly 2 years ago
commit bedbd658fe
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

2
.gitignore vendored

@ -175,6 +175,8 @@ docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/* docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env docker/middleware.env
sdks/python-client/build sdks/python-client/build

@ -7,7 +7,8 @@ Dify is licensed under the Apache License 2.0, with the following additional con
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components. b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend.
- Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker.
Please contact business@dify.ai by email to inquire about licensing matters. Please contact business@dify.ai by email to inquire about licensing matters.

@ -42,7 +42,7 @@ DB_DATABASE=dify
# Storage configuration # Storage configuration
# use for store upload files, private keys... # use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase # storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=local STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false S3_USE_AWS_MANAGED_IAM=false
@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30
UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5 UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration # Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64 MULTIMODAL_SEND_IMAGE_FORMAT=base64
@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5 WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
@ -338,3 +341,6 @@ INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration # Marketplace configuration
MARKETPLACE_ENABLED=true MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai MARKETPLACE_API_URL=https://marketplace.dify.ai
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

@ -1,8 +1,15 @@
{ {
"version": "0.2.0", "version": "0.2.0",
"compounds": [
{
"name": "Launch Flask and Celery",
"configurations": ["Python: Flask", "Python: Celery"]
}
],
"configurations": [ "configurations": [
{ {
"name": "Python: Flask", "name": "Python: Flask",
"consoleName": "Flask",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/.venv/bin/python", "python": "${workspaceFolder}/.venv/bin/python",
@ -17,12 +24,12 @@
}, },
"args": [ "args": [
"run", "run",
"--host=0.0.0.0",
"--port=5001" "--port=5001"
] ]
}, },
{ {
"name": "Python: Celery", "name": "Python: Celery",
"consoleName": "Celery",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/.venv/bin/python", "python": "${workspaceFolder}/.venv/bin/python",
@ -45,10 +52,10 @@
"-c", "-c",
"1", "1",
"--loglevel", "--loglevel",
"info", "DEBUG",
"-Q", "-Q",
"dataset,generation,mail,ops_trace,app_deletion" "dataset,generation,mail,ops_trace,app_deletion"
] ]
}, }
] ]
} }

@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \ && apt-get update \
# For Security # For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \
&& apt-get autoremove -y \ && apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*

@ -10,44 +10,20 @@ if os.environ.get("DEBUG", "false").lower() != "true":
grpc.experimental.gevent.init_gevent() grpc.experimental.gevent.init_gevent()
import json import json
import logging
import sys
import threading import threading
import time import time
import warnings import warnings
from logging.handlers import RotatingFileHandler
from flask import Flask, Response, request from flask import Response
from flask_cors import CORS
from werkzeug.exceptions import Unauthorized
import contexts from app_factory import create_app
from commands import register_commands
from configs import dify_config
# DO NOT REMOVE BELOW # DO NOT REMOVE BELOW
from events import event_handlers # noqa: F401 from events import event_handlers # noqa: F401
from extensions import (
ext_celery,
ext_code_based_extension,
ext_compress,
ext_database,
ext_hosting_provider,
ext_login,
ext_mail,
ext_migrate,
ext_proxy_fix,
ext_redis,
ext_sentry,
ext_storage,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
# TODO: Find a way to avoid importing models here # TODO: Find a way to avoid importing models here
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
from services.account_service import AccountService
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE
@ -60,188 +36,12 @@ if hasattr(time, "tzset"):
time.tzset() time.tzset()
class DifyApp(Flask):
pass
# ------------- # -------------
# Configuration # Configuration
# ------------- # -------------
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():
if isinstance(value, str):
os.environ[key] = value
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ""
return dify_app
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
log_handlers = [
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5,
),
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True,
)
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
handler.formatter.converter = time_converter
initialize_extensions(app)
register_blueprints(app)
register_commands(app)
return app
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
ext_proxy_fix.init_app(app)
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
else:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
# register blueprint routers
def register_blueprints(app):
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
# create app # create app
app = create_app() app = create_app()
celery = app.extensions["celery"] celery = app.extensions["celery"]

@ -0,0 +1,213 @@
import os
if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey
monkey.patch_all()
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import json
import logging
import sys
from logging.handlers import RotatingFileHandler
from flask import Flask, Response, request
from flask_cors import CORS
from werkzeug.exceptions import Unauthorized
import contexts
from commands import register_commands
from configs import dify_config
from extensions import (
ext_celery,
ext_code_based_extension,
ext_compress,
ext_database,
ext_hosting_provider,
ext_login,
ext_mail,
ext_migrate,
ext_proxy_fix,
ext_redis,
ext_sentry,
ext_storage,
)
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
from services.account_service import AccountService
class DifyApp(Flask):
pass
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():
if isinstance(value, str):
os.environ[key] = value
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ""
return dify_app
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
log_handlers = [
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5,
),
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True,
)
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
handler.formatter.converter = time_converter
initialize_extensions(app)
register_blueprints(app)
register_commands(app)
return app
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
ext_proxy_fix.init_app(app)
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
else:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
# register blueprint routers
def register_blueprints(app):
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)

@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
from libs.helper import email as email_validate from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from models.account import Tenant from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
@ -259,6 +259,25 @@ def migrate_knowledge_vector_database():
skipped_count = 0 skipped_count = 0
total_count = 0 total_count = 0
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
upper_colletion_vector_types = {
VectorType.MILVUS,
VectorType.PGVECTOR,
VectorType.RELYT,
VectorType.WEAVIATE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
}
lower_colletion_vector_types = {
VectorType.ANALYTICDB,
VectorType.CHROMA,
VectorType.MYSCALE,
VectorType.PGVECTO_RS,
VectorType.TIDB_VECTOR,
VectorType.OPENSEARCH,
VectorType.TENCENT,
VectorType.BAIDU,
VectorType.VIKINGDB,
}
page = 1 page = 1
while True: while True:
try: try:
@ -284,11 +303,9 @@ def migrate_knowledge_vector_database():
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
continue continue
collection_name = "" collection_name = ""
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id dataset_id = dataset.id
if vector_type in upper_colletion_vector_types:
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT: elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
@ -301,63 +318,15 @@ def migrate_knowledge_vector_database():
else: else:
raise ValueError("Dataset Collection Binding not found") raise ValueError("Dataset Collection Binding not found")
else: else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS: elif vector_type in lower_colletion_vector_types:
dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.BAIDU:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.BAIDU,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
else: else:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
vector = Vector(dataset) vector = Vector(dataset)
click.echo(f"Migrating dataset {dataset.id}.") click.echo(f"Migrating dataset {dataset.id}.")

@ -1,6 +1,15 @@
from typing import Annotated, Optional from typing import Annotated, Literal, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic import (
AliasChoices,
Field,
HttpUrl,
NegativeInt,
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
)
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from configs.feature.hosted_service import HostedServiceConfig
@ -11,16 +20,16 @@ class SecurityConfig(BaseSettings):
Security-related configurations for the application Security-related configurations for the application
""" """
SECRET_KEY: Optional[str] = Field( SECRET_KEY: str = Field(
description="Secret key for secure session cookie signing." description="Secret key for secure session cookie signing."
"Make sure you are changing this key for your deployment with a strong key." "Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
default=None, default="",
) )
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in hours for which a password reset token remains valid", description="Duration in minutes for which a password reset token remains valid",
default=24, default=5,
) )
@ -230,6 +239,16 @@ class FileUploadConfig(BaseSettings):
default=10, default=10,
) )
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="video file size limit in Megabytes for uploading files",
default=100,
)
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="audio file size limit in Megabytes for uploading files",
default=50,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a batch upload operation", description="Maximum number of files allowed in a batch upload operation",
default=20, default=20,
@ -408,8 +427,8 @@ class WorkflowConfig(BaseSettings):
) )
MAX_VARIABLE_SIZE: PositiveInt = Field( MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.", description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=5 * 1024, default=200 * 1024,
) )
@ -526,12 +545,18 @@ class MailConfig(BaseSettings):
default=False, default=False,
) )
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
)
class RagEtlConfig(BaseSettings): class RagEtlConfig(BaseSettings):
""" """
Configuration for RAG ETL processes Configuration for RAG ETL processes
""" """
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
ETL_TYPE: str = Field( ETL_TYPE: str = Field(
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
default="dify", default="dify",
@ -598,7 +623,7 @@ class IndexingConfig(BaseSettings):
class ImageFormatConfig(BaseSettings): class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64", default="base64",
) )
@ -667,6 +692,33 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
default=False,
)
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
description="whether to enable email password login",
default=True,
)
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
description="whether to enable github/google oauth login",
default=False,
)
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="expiry time in minutes for email code login token",
default=5,
)
ALLOW_REGISTER: bool = Field(
description="whether to enable register",
default=False,
)
ALLOW_CREATE_WORKSPACE: bool = Field(
description="whether to enable create workspace",
default=False,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -694,6 +746,7 @@ class FeatureConfig(
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,
WorkspaceConfig, WorkspaceConfig,
LoginConfig,
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,

@ -35,7 +35,8 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings): class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field( STORAGE_TYPE: str = Field(
description="Type of storage to use." description="Type of storage to use."
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.", " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
default="local", default="local",
) )

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.9.2", default="0.10.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -1,2 +1,22 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000" UUID_NIL = "00000000-0000-0000-0000-000000000000"
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

@ -1,7 +1,9 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import TYPE_CHECKING
from core.workflow.entities.variable_pool import VariablePool if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id") tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

@ -22,7 +22,8 @@ from fields.conversation_fields import (
) )
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):

@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.login import login_required from libs.login import login_required
from models.model import Site from models import Site
def parse_app_site_args(): def parse_app_site_args():

@ -13,15 +13,15 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import factory from factories import variable_factory
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App
from models.account import Account from models.account import Account
from models.model import App, AppMode from models.model import AppMode
from services.app_dsl_service import AppDslService from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
@ -105,9 +105,13 @@ class DraftWorkflowApi(Resource):
try: try:
environment_variables_list = args.get("environment_variables") or [] environment_variables_list = args.get("environment_variables") or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or [] conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow = workflow_service.sync_draft_workflow( workflow = workflow_service.sync_draft_workflow(
app_model=app_model, app_model=app_model,
graph=args["graph"], graph=args["graph"],
@ -292,17 +296,15 @@ class DraftWorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource): class WorkflowTaskStopApi(Resource):

@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required from libs.login import login_required
from models.model import App, AppMode from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService from services.workflow_app_service import WorkflowAppService

@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
) )
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from models.model import App, AppMode from models import App
from models.model import AppMode
from services.workflow_run_service import WorkflowRunService from services.workflow_run_service import WorkflowRunService

@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):

@ -5,7 +5,8 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_user
from models.model import App, AppMode from models import App
from models.model import AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):

@ -1,17 +1,15 @@
import base64
import datetime import datetime
import secrets
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen, email, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from libs.password import hash_password, valid_password from models.account import AccountStatus, Tenant
from models.account import AccountStatus from services.account_service import AccountService, RegisterService
from services.account_service import RegisterService
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@ -27,8 +25,18 @@ class ActivateCheckApi(Resource):
token = args["token"] token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} data = invitation.get("data", {})
tenant: Tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
return {
"is_valid": invitation is not None,
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
}
else:
return {"is_valid": False}
class ActivateApi(Resource): class ActivateApi(Resource):
@ -38,7 +46,6 @@ class ActivateApi(Resource):
parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json" "interface_language", type=supported_language, required=True, nullable=False, location="json"
) )
@ -54,15 +61,6 @@ class ActivateApi(Resource):
account = invitation["account"] account = invitation["account"]
account.name = args["name"] account.name = args["name"]
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args["password"], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args["interface_language"] account.interface_language = args["interface_language"]
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
@ -70,7 +68,9 @@ class ActivateApi(Resource):
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(ActivateCheckApi, "/activate/check") api.add_resource(ActivateCheckApi, "/activate/check")

@ -27,5 +27,29 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded" error_code = "password_reset_rate_limit_exceeded"
description = "Password reset rate limit exceeded. Try again later." description = "Too many password reset emails have been sent. Please try again in 1 minutes."
code = 429
class EmailCodeError(BaseHTTPException):
error_code = "email_code_error"
description = "Email code is invalid or expired."
code = 400
class EmailOrPasswordMismatchError(BaseHTTPException):
error_code = "email_or_password_mismatch"
description = "The email or password is mismatched."
code = 400
class EmailPasswordLoginLimitError(BaseHTTPException):
error_code = "email_code_login_limit"
description = "Too many incorrect password attempts. Please try again later."
code = 429
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
error_code = "email_code_login_rate_limit_exceeded"
description = "Too many login emails have been sent. Please try again in 5 minutes."
code = 429 code = 429

@ -1,65 +1,82 @@
import base64 import base64
import logging
import secrets import secrets
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError,
InvalidEmailError, InvalidEmailError,
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
PasswordResetRateLimitExceededError,
) )
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email as email_validate from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import Account from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService, TenantService
from services.errors.account import RateLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
email = args["email"] ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
if not email_validate(email): raise EmailSendIpLimitError()
raise InvalidEmailError()
account = Account.query.filter_by(email=email).first()
if account: if args["language"] is not None and args["language"] == "zh-Hans":
try: language = "zh-Hans"
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else: else:
# Return success to avoid revealing email registration status language = "en-US"
logging.warning(f"Attempt to reset password for unregistered email: {email}")
account = Account.query.filter_by(email=args["email"]).first()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success"} return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
token = args["token"]
reset_data = AccountService.get_reset_password_data(token) user_email = args["email"]
if reset_data is None: token_data = AccountService.get_reset_password_data(args["token"])
return {"is_valid": False, "email": None} if token_data is None:
return {"is_valid": True, "email": reset_data.get("email")} raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
raise EmailCodeError()
return {"is_valid": True, "email": token_data.get("email")}
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@ -92,9 +109,26 @@ class ForgotPasswordResetApi(Resource):
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first() account = Account.query.filter_by(email=reset_data.get("email")).first()
if account:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
db.session.commit() db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
try:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email"),
name=reset_data.get("email"),
password=password_confirm,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
return {"result": "success"} return {"result": "success"}

@ -1,16 +1,34 @@
from typing import cast from typing import cast
import flask_login import flask_login
from flask import request from flask import redirect, request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
import services import services
from configs import dify_config
from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
EmailOrPasswordMismatchError,
EmailPasswordLoginLimitError,
InvalidEmailError,
InvalidTokenError,
)
from controllers.console.error import (
AccountBannedError,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
NotAllowedRegister,
)
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models.account import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class LoginApi(Resource): class LoginApi(Resource):
@ -23,15 +41,43 @@ class LoginApi(Resource):
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
args = parser.parse_args() args = parser.parse_args()
# todo: Verify the recaptcha is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try: try:
if invitation:
data = invitation.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
else:
account = AccountService.authenticate(args["email"], args["password"]) account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e: except services.errors.account.AccountLoginError:
return {"code": "unauthorized", "message": str(e)}, 401 raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:
@ -41,7 +87,7 @@ class LoginApi(Resource):
} }
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@ -49,60 +95,114 @@ class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account) AccountService.logout(account=account)
flask_login.logout_user() flask_login.logout_user()
return {"result": "success"} return {"result": "success"}
class ResetPasswordApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
def get(self): def post(self):
# parser = reqparse.RequestParser() parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json') parser.add_argument("email", type=email, required=True, location="json")
# args = parser.parse_args() parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
# account = {'email': args['email']}
# account = AccountService.get_by_email(args['email'])
# if account is None:
# raise ValueError('Email not found')
# new_password = AccountService.generate_password()
# AccountService.update_password(account, new_password)
# todo: Send email
# MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
# response = mailchimp.messages.send({
# 'message': message,
# # required for transactional email
# ' settings': {
# 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE,
# },
# })
# Check if MSG was sent
# if response.status_code != 200:
# # handle error
# pass
return {"result": "success"} if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = AccountService.get_user_through_email(args["email"])
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = AccountService.get_user_through_email(args["email"])
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
account = AccountService.get_user_through_email(user_email)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
@ -120,4 +220,7 @@ class RefreshTokenApi(Resource):
api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout") api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token") api.add_resource(RefreshTokenApi, "/refresh-token")

@ -5,14 +5,20 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus from models import Account
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api from .. import api
@ -42,6 +48,7 @@ def get_oauth_providers():
class OAuthLogin(Resource): class OAuthLogin(Resource):
def get(self, provider: str): def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider) oauth_provider = OAUTH_PROVIDERS.get(provider)
@ -49,7 +56,7 @@ class OAuthLogin(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url() auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
return redirect(auth_url) return redirect(auth_url)
@ -62,6 +69,11 @@ class OAuthCallback(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
code = request.args.get("code") code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
@ -69,7 +81,27 @@ class OAuthCallback(Resource):
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info) account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except WorkSpaceNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
# Check account status # Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403 return {"error": "Account is banned or closed."}, 403
@ -79,7 +111,15 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
try:
TenantService.create_owner_tenant_if_not_exist(account) TenantService.create_owner_tenant_if_not_exist(account)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token_pair = AccountService.login( token_pair = AccountService.login(
account=account, account=account,
@ -104,8 +144,20 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email. # Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info) account = _get_account_by_openid_or_email(provider, user_info)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if not account: if not account:
# Create account if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError()
account_name = user_info.name or "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required from libs.login import login_required
from models.dataset import Document from models import DataSourceOauthBinding, Document
from models.source import DataSourceOauthBinding
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task from tasks.document_indexing_sync_task import document_indexing_sync_task

@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import login_required
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.model import ApiToken, UploadFile from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService

@ -45,8 +45,7 @@ from fields.document_fields import (
document_with_segments_fields, document_with_segments_fields,
) )
from libs.login import login_required from libs.login import login_required
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from tasks.add_document_to_index_task import add_document_to_index_task from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task

@ -24,7 +24,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields from fields.segment_fields import segment_fields
from libs.login import login_required from libs.login import login_required
from models.dataset import DocumentSegment from models import DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task

@ -1,9 +1,12 @@
import urllib.parse
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import ( from controllers.console.datasets.error import (
FileTooLargeError, FileTooLargeError,
@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
) )
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
from libs.login import login_required from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000 PREVIEW_WORDS_LIMIT = 3000
@ -44,6 +48,10 @@ class FileApi(Resource):
# get file from request # get file from request
file = request.files["file"] file = request.files["file"]
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -51,7 +59,7 @@ class FileApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
try: try:
upload_file = FileService.upload_file(file, current_user) upload_file = FileService.upload_file(file=file, user=current_user, source=source)
except services.errors.file.FileTooLargeError as file_too_large_error: except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description) raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
@ -75,11 +83,24 @@ class FileSupportTypeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
etl_type = dify_config.ETL_TYPE return {"allowed_extensions": DOCUMENT_EXTENSIONS}
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions}
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload") api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type") api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

@ -38,3 +38,27 @@ class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate" error_code = "already_activate"
description = "Auth Token is invalid or account already activated, please check again." description = "Auth Token is invalid or account already activated, please check again."
code = 403 code = 403
class NotAllowedCreateWorkspace(BaseHTTPException):
error_code = "unauthorized"
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
code = 400
class AccountBannedError(BaseHTTPException):
error_code = "account_banned"
description = "Account is banned."
code = 400
class NotAllowedRegister(BaseHTTPException):
error_code = "unauthorized"
description = "Account not found."
code = 400
class EmailSendIpLimitError(BaseHTTPException):
error_code = "email_send_ip_limit"
description = "Too many emails have been sent from this IP address recently. Please try again later."
code = 429

@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required from libs.login import login_required
from models.model import App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService

@ -18,7 +18,7 @@ message_fields = {
"inputs": fields.Raw, "inputs": fields.Raw,
"query": fields.String, "query": fields.String,
"answer": fields.String, "answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
} }

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
from models.model import InstalledApp from models import InstalledApp
def installed_app_required(view=None): def installed_app_required(view=None):

@ -20,7 +20,7 @@ from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone from libs.helper import TimestampField, timezone
from libs.login import login_required from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

@ -397,16 +397,15 @@ class ToolWorkflowProviderCreateApi(Resource):
args = reqparser.parse_args() args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool( return WorkflowToolManageService.create_workflow_tool(
user_id, user_id=user_id,
tenant_id, tenant_id=tenant_id,
args["workflow_app_id"], workflow_app_id=args["workflow_app_id"],
args["name"], name=args["name"],
args["label"], label=args["label"],
args["icon"], icon=args["icon"],
args["description"], description=args["description"],
args["parameters"], parameters=args["parameters"],
args["privacy_policy"], privacy_policy=args["privacy_policy"],
args.get("labels", []),
) )

@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
try: try:
upload_file = FileService.upload_file(file, current_user, True) upload_file = FileService.upload_file(file=file, user=current_user)
except services.errors.file.FileTooLargeError as file_too_large_error: except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description) raise FileTooLargeError(file_too_large_error.description)

@ -10,6 +10,34 @@ from services.file_service import FileService
class ImagePreviewApi(Resource): class ImagePreviewApi(Resource):
"""
Deprecated
"""
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_image_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
@ -21,7 +49,12 @@ class ImagePreviewApi(Resource):
return {"content": "Invalid request."}, 400 return {"content": "Invalid request."}, 400
try: try:
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) generator, mimetype = FileService.get_signed_file_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
@ -49,4 +82,5 @@ class WorkspaceWebappLogoApi(Resource):
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview") api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo") api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")

@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource):
parser.add_argument("timestamp", type=str, required=True, location="args") parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args") parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args") parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource):
raise Forbidden("Invalid request.") raise Forbidden("Invalid request.")
try: try:
result = ToolFileManager.get_file_generator_by_tool_file_id( stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
file_id, file_id,
) )
if not result: if not stream or not tool_file:
raise NotFound("file is not found") raise NotFound("file is not found")
generator, mimetype = result
except Exception: except Exception:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype) response = Response(
stream,
mimetype=tool_file.mimetype,
direct_passthrough=True,
headers={
"Content-Length": str(tool_file.size),
},
)
if args["as_attachment"]:
response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
return response
api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>") api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")

@ -48,7 +48,7 @@ class MessageListApi(Resource):
"tool_input": fields.String, "tool_input": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
"observation": fields.String, "observation": fields.String,
"message_files": fields.List(fields.String, attribute="files"), "message_files": fields.List(fields.String),
} }
message_fields = { message_fields = {
@ -58,7 +58,7 @@ class MessageListApi(Resource):
"inputs": fields.Raw, "inputs": fields.Raw,
"query": fields.String, "query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"), "answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,

@ -1,11 +1,14 @@
import urllib.parse
from flask import request from flask import request
from flask_restful import marshal_with from flask_restful import marshal_with, reqparse
import services import services
from controllers.web import api from controllers.web import api
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields
from services.file_service import FileService from services.file_service import FileService
@ -15,6 +18,10 @@ class FileApi(WebApiResource):
# get file from request # get file from request
file = request.files["file"] file = request.files["file"]
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -22,7 +29,7 @@ class FileApi(WebApiResource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
try: try:
upload_file = FileService.upload_file(file, end_user) upload_file = FileService.upload_file(file=file, user=end_user, source=source)
except services.errors.file.FileTooLargeError as file_too_large_error: except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description) raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
@ -31,4 +38,19 @@ class FileApi(WebApiResource):
return upload_file, 201 return upload_file, 201
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload") api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields from fields.message_fields import agent_thought_fields
from fields.raws import FilesContainedField
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import AppMode from models.model import AppMode
@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
"id": fields.String, "id": fields.String,
"conversation_id": fields.String, "conversation_id": fields.String,
"parent_message_id": fields.String, "parent_message_id": fields.String,
"inputs": fields.Raw, "inputs": FilesContainedField,
"query": fields.String, "query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"), "answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,

@ -17,7 +17,7 @@ message_fields = {
"inputs": fields.Raw, "inputs": fields.Raw,
"query": fields.String, "query": fields.String,
"answer": fields.String, "answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
} }

@ -15,12 +15,12 @@ from core.app.entities.app_invoke_entities import (
) )
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
LLMUsage,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
PromptMessageTool, PromptMessageTool,
@ -38,9 +38,9 @@ from core.tools.entities.tool_entities import (
) )
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,23 +61,6 @@ class BaseAgentRunner(AppRunner):
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
) -> None: ) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
@ -172,7 +155,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM: if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) parameter_type = parameter.type.as_normal_type()
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
@ -259,7 +242,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM: if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) parameter_type = parameter.type.as_normal_type()
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
@ -492,18 +475,15 @@ class BaseAgentRunner(AppRunner):
return result return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
message_file_parser = MessageFileParser( files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
)
files = message.message_files
if files: if files:
assert message.app_model_config assert message.app_model_config
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config) file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else: else:
file_objs = [] file_objs = []
@ -512,7 +492,7 @@ class BaseAgentRunner(AppRunner):
else: else:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs: for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
return UserPromptMessage(content=prompt_message_contents) return UserPromptMessage(content=prompt_message_contents)
else: else:

@ -1,7 +1,8 @@
import json import json
from core.agent.cot_agent_runner import CotAgentRunner from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import ( from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
@ -38,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files: for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:

@ -7,9 +7,13 @@ from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.file import file_manager
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
PromptMessageContentType, PromptMessageContentType,
@ -390,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files: for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:

@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
VariableEntity( VariableEntity(
type=variable_type, type=variable_type,
variable=variable.get("variable"), variable=variable.get("variable"),
description=variable.get("description"), description=variable.get("description", ""),
label=variable.get("label"), label=variable.get("label"),
required=variable.get("required", False), required=variable.get("required", False),
max_length=variable.get("max_length"), max_length=variable.get("max_length"),
options=variable.get("options"), options=variable.get("options", []),
default=variable.get("default"),
) )
) )

@ -1,11 +1,12 @@
from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.file.file_obj import FileExtraConfig from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessageRole
from models import AppMode from models.model import AppMode
class ModelConfigEntity(BaseModel): class ModelConfigEntity(BaseModel):
@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
ADVANCED = "advanced" ADVANCED = "advanced"
@classmethod @classmethod
def value_of(cls, value: str) -> "PromptType": def value_of(cls, value: str):
""" """
Get value of given mode. Get value of given mode.
@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"
NUMBER = "number" NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool" EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
class VariableEntity(BaseModel): class VariableEntity(BaseModel):
@ -102,13 +105,14 @@ class VariableEntity(BaseModel):
variable: str variable: str
label: str label: str
description: Optional[str] = None description: str = ""
type: VariableEntityType type: VariableEntityType
required: bool = False required: bool = False
max_length: Optional[int] = None max_length: Optional[int] = None
options: Optional[list[str]] = None options: Sequence[str] = Field(default_factory=list)
default: Optional[str] = None allowed_file_types: Sequence[FileType] = Field(default_factory=list)
hint: Optional[str] = None allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
MULTIPLE = "multiple" MULTIPLE = "multiple"
@classmethod @classmethod
def value_of(cls, value: str) -> "RetrieveStrategy": def value_of(cls, value: str):
""" """
Get value of given mode. Get value of given mode.

@ -1,12 +1,13 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any
from core.file.file_obj import FileExtraConfig from core.file.models import FileExtraConfig
from models import FileUploadConfig
class FileUploadConfigManager: class FileUploadConfigManager:
@classmethod @classmethod
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
""" """
Convert model config to model config Convert model config to model config
@ -15,19 +16,18 @@ class FileUploadConfigManager:
""" """
file_upload_dict = config.get("file_upload") file_upload_dict = config.get("file_upload")
if file_upload_dict: if file_upload_dict:
if file_upload_dict.get("image"): if file_upload_dict.get("enabled"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: data = {
image_config = { "image_config": {
"number_limits": file_upload_dict["image"]["number_limits"], "number_limits": file_upload_dict["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"], "transfer_methods": file_upload_dict["allowed_file_upload_methods"],
}
} }
if is_vision: if is_vision:
image_config["detail"] = file_upload_dict["image"]["detail"] data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
return FileExtraConfig(image_config=image_config)
return None return FileExtraConfig.model_validate(data)
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
@ -39,29 +39,7 @@ class FileUploadConfigManager:
""" """
if not config.get("file_upload"): if not config.get("file_upload"):
config["file_upload"] = {} config["file_upload"] = {}
else:
if not isinstance(config["file_upload"], dict): FileUploadConfig.model_validate(config["file_upload"])
raise ValueError("file_upload must be of dict type")
# check image config
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"] return config, ["file_upload"]

@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
# variables # variables
for variable in user_input_form: for variable in user_input_form:
variables.append(VariableEntity(**variable)) variables.append(VariableEntity.model_validate(variable))
return variables return variables

@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow from models.workflow import Workflow
@ -107,10 +108,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else: else:
file_objs = [] file_objs = []
@ -118,8 +125,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(
trace_manager = TraceQueueManager(app_model.id, user_id) app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
if invoke_from == InvokeFrom.DEBUGGER: if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
@ -131,7 +139,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -1,31 +1,27 @@
import logging import logging
import os
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType from models.workflow import ConversationVariable, WorkflowType
@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
) -> None: ) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager) super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.message = message self.message = message
def run(self) -> None: def run(self) -> None:
"""
Run application
:return:
"""
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"): if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query: str, query: str,
message_id: str, message_id: str,
) -> bool: ) -> bool:
"""
Handle input moderation
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try: try:
# process sensitive_word_avoidance # process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs( _, inputs, query = self.moderation_for_inputs(
@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def handle_annotation_reply( def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool: ) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply( annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record, app_record=app_record,
message=message, message=message,
@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
""" """
Direct output Direct output
:param text: text
:return:
""" """
self._publish_event(QueueTextChunkEvent(text=text)) self._publish_event(QueueTextChunkEvent(text=text))

@ -1,7 +1,7 @@
import json import json
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Any, Optional, Union from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
InvokeFrom,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent, QueueAdvancedChatMessageEndEvent,
@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account from models.account import Account
from models.model import Conversation, EndUser, Message from models.enums import CreatedByRole
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution, WorkflowNodeExecution,
@ -120,6 +123,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._wip_workflow_node_executions = {} self._wip_workflow_node_executions = {}
self._conversation_name_generate_thread = None self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self): def process(self):
""" """
@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event) workflow_node_execution = self._handle_workflow_node_execution_success(event)
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None, outputs=event.outputs,
conversation_id=self._conversation.id, conversation_id=self._conversation.id,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._refetch_message() self._refetch_message()
self._message.answer = self._task_state.answer self._message.answer = self._task_state.answer
@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = ( self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
message_files = [
MessageFile(
message_id=self._message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER,
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
)
for file in self._recorded_files
]
db.session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage: if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage usage = graph_runtime_state.llm_usage
@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"] del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
) )
def _handle_output_moderation_chunk(self, text: str) -> bool: def _handle_output_moderation_chunk(self, text: str) -> bool:

@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from factories import file_factory
from models.model import App, EndUser from models import Account, App, EndUser
from models.enums import CreatedByRole
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -108,12 +108,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args.get("files") or []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else: else:
file_objs = [] file_objs = []
@ -126,8 +133,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
) )
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
trace_manager = TraceQueueManager(app_model.id, user_id)
# init application generate entity # init application generate entity
application_generate_entity = AgentChatAppGenerateEntity( application_generate_entity = AgentChatAppGenerateEntity(
@ -135,7 +141,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -1,36 +1,93 @@
import json import json
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from factories import file_factory
if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole
class BaseAppGenerator: class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables variables = app_config.variables
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
return filtered_inputs # Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
}
# Convert list of files to File
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, list)
# Ensure skip List<File>
and all(isinstance(item, dict) for item in v)
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
}
# Merge all inputs
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable) user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form") raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value: else:
# TODO: should we return None here if the default value is None? return None
return var.default or ""
if ( if var.type in {
var.type
in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT, VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH, VariableEntityType.PARAGRAPH,
} } and not isinstance(user_input_value, str):
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
@ -42,12 +99,24 @@ class BaseAppGenerator:
except ValueError: except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number") raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT: if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}") raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
return user_input_value return user_input_value

@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.file_obj import FileVar from core.file.models import File
class AppRunner: class AppRunner:
@ -37,7 +37,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list["FileVar"], files: list["File"],
query: Optional[str] = None, query: Optional[str] = None,
) -> int: ) -> int:
""" """
@ -137,7 +137,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list["FileVar"], files: list["File"],
query: Optional[str] = None, query: Optional[str] = None,
context: Optional[str] = None, context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,

@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser from models.model import App, EndUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,12 +111,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else: else:
file_objs = [] file_objs = []
@ -128,7 +136,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) )
# get tracing instance # get tracing instance
trace_manager = TraceQueueManager(app_model.id) trace_manager = TraceQueueManager(app_id=app_model.id)
# init application generate entity # init application generate entity
application_generate_entity = ChatAppGenerateEntity( application_generate_entity = ChatAppGenerateEntity(
@ -136,15 +144,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id, user_id=user.id,
stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager, trace_manager=trace_manager,
stream=stream,
) )
# init generate records # init generate records

@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from factories import file_factory
from models.model import App, EndUser, Message from models import Account, App, EndUser, Message
from models.enums import CreatedByRole
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
@ -98,12 +98,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config") tenant_id=app_model.tenant_id, config=args.get("model_config")
) )
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else: else:
file_objs = [] file_objs = []
@ -113,6 +120,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
) )
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id) trace_manager = TraceQueueManager(app_model.id)
# init application generate entity # init application generate entity
@ -120,7 +128,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
inputs=self._get_cleaned_inputs(inputs, app_config), inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
user_id=user.id, user_id=user.id,
@ -261,10 +269,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict override_model_config_dict["model"] = model_dict
# parse files # parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else: else:
file_objs = [] file_objs = []

@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models import Account
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
for file in application_generate_entity.files: for file in application_generate_entity.files:
message_file = MessageFile( message_file = MessageFile(
message_id=message.id, message_id=message.id,
type=file.type.value, type=file.type,
transfer_method=file.transfer_method.value, transfer_method=file.transfer_method,
belongs_to="user", belongs_to="user",
url=file.url, url=file.remote_url,
upload_file_id=file.related_id, upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"), created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id, created_by=account_id or end_user_id or "",
) )
db.session.add(message_file) db.session.add(message_file)
db.session.commit() db.session.commit()

@ -3,7 +3,7 @@ import logging
import os import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from factories import file_factory
from models.model import App, EndUser from models import Account, App, EndUser, Workflow
from models.workflow import Workflow from models.enums import CreatedByRole
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -75,49 +74,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
): ):
""" files: Sequence[Mapping[str, Any]] = args.get("files") or []
Generate App response.
:param app_model: App role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args["inputs"]
# parse files # parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: system_files = file_factory.build_from_mappings(
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) mappings=files,
else: tenant_id=app_model.tenant_id,
file_objs = [] user_id=user.id,
role=role,
config=file_extra_config,
)
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(
trace_manager = TraceQueueManager(app_model.id, user_id) app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
inputs: Mapping[str, Any] = args["inputs"]
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
inputs=self._get_cleaned_inputs(inputs, app_config), inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
files=file_objs, files=system_files,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,

@ -1,21 +1,20 @@
import logging import logging
import os
from typing import Optional, cast from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, EndUser from models.model import App, EndUser
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
db.session.close() db.session.close()
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"): if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested # if only single iteration run is requested

@ -1,4 +1,3 @@
import json
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) outputs=event.outputs,
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
) )

@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
) )
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
GraphEngineEvent, GraphEngineEvent,
@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunSucceededEvent, ParallelBranchRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App from models.model import App
@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner):
raise ValueError("iteration node id not found in workflow graph") raise ValueError("iteration node id not found in workflow graph")
# Get node class # Get node class
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type) node_cls = node_type_classes_mapping[node_type]
node_cls = cast(type[BaseNode], node_cls)
# init variable pool # init variable pool
variable_pool = VariablePool( variable_pool = VariablePool(

@ -1,4 +1,4 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileVar from core.file.models import File
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
@ -23,7 +23,7 @@ class InvokeFrom(Enum):
DEBUGGER = "debugger" DEBUGGER = "debugger"
@classmethod @classmethod
def value_of(cls, value: str) -> "InvokeFrom": def value_of(cls, value: str):
""" """
Get value of given mode. Get value of given mode.
@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel):
app_config: AppConfig app_config: AppConfig
inputs: Mapping[str, Any] inputs: Mapping[str, Any]
files: list[FileVar] = [] files: Sequence[File]
user_id: str user_id: str
# extras # extras

@ -5,9 +5,10 @@ from typing import Any, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum): class QueueEvent(str, Enum):

@ -1,3 +1,4 @@
from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END event: StreamEvent = StreamEvent.MESSAGE_END
id: str id: str
metadata: dict = {} metadata: dict = {}
files: Optional[Sequence[Mapping[str, Any]]] = None
class MessageFileStreamResponse(StreamResponse): class MessageFileStreamResponse(StreamResponse):
@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None created_by: Optional[dict] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[list[dict]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str workflow_run_id: str
@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None execution_metadata: Optional[dict] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[list[dict]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None parent_parallel_id: Optional[str] = None

@ -1,18 +0,0 @@
import re
from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))
return SegmentGroup(value=segments)

@ -1,5 +1,6 @@
import json import json
import time import time
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -27,27 +28,26 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState, WorkflowTaskState,
) )
from core.file.file_obj import FileVar from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
from models.workflow import ( from models.workflow import (
CreatedByRole,
Workflow, Workflow,
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowRunTriggeredFrom,
) )
@ -117,7 +117,7 @@ class WorkflowCycleManage:
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
total_steps: int, total_steps: int,
outputs: Optional[str] = None, outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun: ) -> WorkflowRun:
@ -133,8 +133,10 @@ class WorkflowCycleManage:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
@ -265,6 +267,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -276,7 +279,7 @@ class WorkflowCycleManage:
{ {
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata, WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.finished_at: finished_at,
@ -286,10 +289,11 @@ class WorkflowCycleManage:
db.session.commit() db.session.commit()
db.session.close() db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
@ -308,6 +312,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
@ -317,7 +322,7 @@ class WorkflowCycleManage:
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time, WorkflowNodeExecution.elapsed_time: elapsed_time,
@ -326,11 +331,12 @@ class WorkflowCycleManage:
db.session.commit() db.session.commit()
db.session.close() db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
@ -637,7 +643,7 @@ class WorkflowCycleManage:
), ),
) )
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
""" """
Fetch files from node outputs Fetch files from node outputs
:param outputs_dict: node outputs dict :param outputs_dict: node outputs dict
@ -646,15 +652,15 @@ class WorkflowCycleManage:
if not outputs_dict: if not outputs_dict:
return [] return []
files = [] files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
for output_var, output_value in outputs_dict.items(): # Remove None
file_vars = self._fetch_files_from_variable_value(output_value) files = [file for file in files if file]
if file_vars: # Flatten list
files.extend(file_vars) files = [file for sublist in files for file in sublist]
return files return files
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
""" """
Fetch files from variable value Fetch files from variable value
:param value: variable value :param value: variable value
@ -666,17 +672,17 @@ class WorkflowCycleManage:
files = [] files = []
if isinstance(value, list): if isinstance(value, list):
for item in value: for item in value:
file_var = self._get_file_var_from_value(item) file = self._get_file_var_from_value(item)
if file_var: if file:
files.append(file_var) files.append(file)
elif isinstance(value, dict): elif isinstance(value, dict):
file_var = self._get_file_var_from_value(value) file = self._get_file_var_from_value(value)
if file_var: if file:
files.append(file_var) files.append(file)
return files return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
""" """
Get file var from value Get file var from value
:param value: variable value :param value: variable value
@ -685,14 +691,11 @@ class WorkflowCycleManage:
if not value: if not value:
return None return None
if isinstance(value, dict): if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value return value
elif isinstance(value, FileVar): elif isinstance(value, File):
return value.to_dict() return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
""" """
Refetch workflow run Refetch workflow run

@ -1,29 +0,0 @@
import enum
from typing import Any
from pydantic import BaseModel
class PromptMessageFileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any = None
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = "low"
HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW

@ -6,7 +6,24 @@ from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from models.provider import ProviderQuotaType
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum): class QuotaUnit(Enum):

@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
File,
FileExtraConfig,
ImageConfig,
)
__all__ = [
"FileType",
"FileExtraConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]

@ -0,0 +1 @@
FILE_MODEL_IDENTITY = "__dify__file__"

@ -0,0 +1,55 @@
from enum import Enum
class FileType(str, Enum):
IMAGE = "image"
DOCUMENT = "document"
AUDIO = "audio"
VIDEO = "video"
CUSTOM = "custom"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(str, Enum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(str, Enum):
TYPE = "type"
SIZE = "size"
NAME = "name"
MIME_TYPE = "mime_type"
TRANSFER_METHOD = "transfer_method"
URL = "url"
EXTENSION = "extension"
class ArrayFileAttribute(str, Enum):
LENGTH = "length"

@ -0,0 +1,156 @@
import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
from extensions.ext_database import db
from extensions.ext_storage import storage
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
return file.type.value
case FileAttribute.SIZE:
return file.size
case FileAttribute.NAME:
return file.filename
case FileAttribute.MIME_TYPE:
return file.mime_type
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return file.remote_url
case FileAttribute.EXTENSION:
return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content(f: File, /):
"""
Convert a File object to an ImagePromptMessageContent object.
This function takes a File object and converts it to an ImagePromptMessageContent
object, which can be used as a prompt for image-based AI models.
Args:
file (File): The File object to convert. Must be of type FileType.IMAGE.
Returns:
ImagePromptMessageContent: An object containing the image data and detail level.
Raises:
ValueError: If the file is not an image or if the file data is missing.
Note:
The detail level of the image prompt is determined by the file's extra_config.
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
"""
match f.type:
case FileType.IMAGE:
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
detail = f._extra_config.image_config.detail
else:
detail = ImagePromptMessageContent.DETAIL.LOW
return ImagePromptMessageContent(data=data, detail=detail)
case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
def _download_file_content(path: str, /):
"""
Download and return the contents of a file as bytes.
This function loads the file from storage and ensures it's in bytes format.
Args:
path (str): The path to the file in storage.
Returns:
bytes: The contents of the file as a bytes object.
Raises:
ValueError: If the loaded file is not a bytes object.
"""
data = storage.load(path, stream=False)
if not isinstance(data, bytes):
raise ValueError(f"file {path} is not a bytes object")
return data
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url)
response.raise_for_status()
content = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
return f.remote_url
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

@ -1,145 +0,0 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileVar(BaseModel):
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str] = None # remote url
related_id: Optional[str] = None
extra_config: Optional[FileExtraConfig] = None
filename: Optional[str] = None
extension: Optional[str] = None
mime_type: Optional[str] = None
def to_dict(self) -> dict:
return {
"__variant": self.__class__.__name__,
"tenant_id": self.tenant_id,
"type": self.type.value,
"transfer_method": self.transfer_method.value,
"url": self.preview_url,
"remote_url": self.url,
"related_id": self.related_id,
"filename": self.filename,
"extension": self.extension,
"mime_type": self.mime_type,
}
def to_markdown(self) -> str:
"""
Convert file to markdown
:return:
"""
preview_url = self.preview_url
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})'
else:
text = f"[{self.filename or preview_url}]({preview_url})"
return text
@property
def data(self) -> Optional[str]:
"""
Get image data, file signed url or base64 data
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
:return:
"""
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
"""
Get signed preview url
:return:
"""
return self._get_data(force_url=True)
@property
def prompt_message_content(self) -> ImagePromptMessageContent:
if self.type == FileType.IMAGE:
image_config = self.extra_config.image_config
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (
db.session.query(UploadFile)
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
.first()
)
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension
# add sign url
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None

@ -0,0 +1,32 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import ToolFile, UploadFile
from .models import File
def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record
def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record

@ -0,0 +1,48 @@
import base64
import hashlib
import hmac
import os
import time
from configs import dify_config
def get_signed_file_url(upload_file_id: str) -> str:
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = dify_config.SECRET_KEY.encode()
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

@ -1,243 +0,0 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
import requests
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
"""
validate and transform files arg
:param files:
:param file_extra_config:
:param user:
:return:
"""
for file in files:
if not isinstance(file, dict):
raise ValueError("Invalid file format, must be dict")
if not file.get("type"):
raise ValueError("Missing file type")
FileType.value_of(file.get("type"))
if not file.get("transfer_method"):
raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get("transfer_method"))
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get("url"):
raise ValueError("Missing file url")
if not file.get("url").startswith("http"):
raise ValueError("Invalid file url")
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError("Missing file upload_file_id")
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError("Missing file tool_file_id")
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_extra_config.image_config
# check if image file feature is enabled
if not image_config:
continue
# Validate number of files
if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError("Invalid upload file")
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
"""
transform message files
:param files:
:param file_extra_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
"""
transform files to file objs
:param files:
:param file_extra_config:
:return:
"""
type_file_objs: dict[FileType, list[FileVar]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
if isinstance(file, MessageFile):
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
continue
file_obj = self._to_file_obj(file, file_extra_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config,
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=None,
related_id=file.get("tool_file_id"),
extra_config=file_extra_config,
)
else:
return FileVar(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
related_id=file.upload_file_id or None,
extra_config=file_extra_config,
)
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
def is_s3_presigned_url(url):
try:
parsed_url = urlparse(url)
if "amazonaws.com" not in parsed_url.netloc:
return False
query_params = parse_qs(parsed_url.query)
def check_presign_v2(query_params):
required_params = ["Signature", "Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["Expires"][0].isdigit():
return False
signature = query_params["Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
def check_presign_v4(query_params):
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["X-Amz-Expires"][0].isdigit():
return False
signature = query_params["X-Amz-Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
return check_presign_v4(query_params) or check_presign_v2(query_params)
except Exception:
return False
if is_s3_presigned_url(url):
response = requests.get(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

@ -0,0 +1,140 @@
from collections.abc import Mapping, Sequence
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
class ImageConfig(BaseModel):
"""
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
"""
number_limits: int = 0
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
detail: ImagePromptMessageContent.DETAIL | None = None
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0
class File(BaseModel):
dify_model_identity: str = FILE_MODEL_IDENTITY
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
remote_url: Optional[str] = None # remote url
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
mime_type: Optional[str] = None
size: int = -1
_extra_config: FileExtraConfig | None = None
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
**data,
"url": self.generate_url(),
}
@property
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({url})'
else:
text = f"[{self.filename or url}]({url})"
return text
def generate_url(self) -> Optional[str]:
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=self.extension
)
else:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=self.extension
)
@model_validator(mode="after")
def validate_after(self):
match self.transfer_method:
case FileTransferMethod.REMOTE_URL:
if not self.remote_url:
raise ValueError("Missing file url")
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
raise ValueError("Invalid file url")
case FileTransferMethod.LOCAL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
case FileTransferMethod.TOOL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
# Validate the extra config.
if not self._extra_config:
return self
if self._extra_config.allowed_file_types:
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
raise ValueError(f"Invalid file type: {self.type}")
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
raise ValueError(f"Invalid file extension: {self.extension}")
if (
self._extra_config.allowed_upload_methods
and self.transfer_method not in self._extra_config.allowed_upload_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
match self.type:
case FileType.IMAGE:
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
if not self._extra_config.image_config:
return self
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
if (
self._extra_config.image_config.transfer_methods
and self.transfer_method not in self._extra_config.image_config.transfer_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
return self

@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
proxies = ( proxy_mounts = (
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} {
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
}
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
else None else None
) )
@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries: while retries <= max_retries:
try: try:
if SSRF_PROXY_ALL_URL: if SSRF_PROXY_ALL_URL:
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
elif proxies: response = client.request(method=method, url=url, **kwargs)
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) elif proxy_mounts:
with httpx.Client(mounts=proxy_mounts) as client:
response = client.request(method=method, url=url, **kwargs)
else: else:
response = httpx.request(method=method, url=url, **kwargs) with httpx.Client() as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:
return response return response

@ -3,9 +3,8 @@ from typing import Optional
from flask import Config, Flask from flask import Config, Flask
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.provider_entities import QuotaUnit, RestrictModel from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class HostingQuota(BaseModel): class HostingQuota(BaseModel):

@ -1,18 +1,20 @@
from typing import Optional from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser from core.file import file_manager
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageRole, PromptMessageRole,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
@ -65,13 +67,12 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = [] prompt_messages = []
for message in messages: for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files: if files:
file_extra_config = None file_extra_config = None
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else: else:
if message.workflow_run_id: if message.workflow_run_id:
@ -84,17 +85,21 @@ class TokenBufferMemory:
workflow_run.workflow.features_dict, is_vision=False workflow_run.workflow.features_dict, is_vision=False
) )
if file_extra_config: if file_extra_config and app_record:
file_objs = message_file_parser.transform_message_files(files, file_extra_config) file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
else: else:
file_objs = [] file_objs = []
if not file_objs: if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query)) prompt_messages.append(UserPromptMessage(content=message.query))
else: else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs: for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message = file_manager.to_prompt_message_content(file_obj)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:

@ -1,9 +1,9 @@
import logging import logging
import os import os
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Literal, Optional, Union, cast, overload from typing import IO, Any, Literal, Optional, Union, cast, overload
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
@ -310,9 +310,7 @@ class ModelInstance:
user=user, user=user,
) )
def invoke_tts( def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None
) -> Generator[bytes, None, None]:
""" """
Invoke large language tts model Invoke large language tts model
@ -336,7 +334,7 @@ class ModelInstance:
voice=voice, voice=voice,
) )
def _round_robin_invoke(self, function: Callable, *args, **kwargs): def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke

@ -0,0 +1,38 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",
"ModelPropertyKey",
"AssistantPromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageRole",
"SystemPromptMessage",
"TextPromptMessageContent",
"UserPromptMessage",
"PromptMessageTool",
"ToolPromptMessage",
"PromptMessageContentType",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
]

@ -2,7 +2,7 @@ from abc import ABC
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum): class PromptMessageRole(Enum):
@ -55,6 +55,7 @@ class PromptMessageContentType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
AUDIO = "audio"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -74,12 +75,18 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT type: PromptMessageContentType = PromptMessageContentType.TEXT
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(PromptMessageContent): class ImagePromptMessageContent(PromptMessageContent):
""" """
Model class for image prompt message content. Model class for image prompt message content.
""" """
class DETAIL(Enum): class DETAIL(str, Enum):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"

@ -1,11 +1,11 @@
import logging import logging
import os
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
from configs import dify_config
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
@ -68,7 +68,7 @@ class LargeLanguageModel(AIModel):
callbacks = callbacks or [] callbacks = callbacks or []
if bool(os.environ.get("DEBUG", "False").lower() == "true"): if dify_config.DEBUG:
callbacks.append(LoggingCallback()) callbacks.append(LoggingCallback())
# trigger before invoke callbacks # trigger before invoke callbacks

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from core.embedding.embedding_constant import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel

@ -1,4 +1,5 @@
import logging import logging
from collections.abc import Iterable
from typing import Optional from typing import Optional
from pydantic import ConfigDict from pydantic import ConfigDict
@ -21,8 +22,14 @@ class TTSModel(AIModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
def invoke( def invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None self,
): model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
""" """
Invoke large language model Invoke large language model
@ -52,12 +59,12 @@ class TTSModel(AIModel):
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]:
""" """
Get voice for given tts model voices Retrieves the list of voices supported by a given text-to-speech (TTS) model.
:param language: tts language :param language: The language for which the voices are requested.
:param model: model name :param model: The name of the TTS model.
:param credentials: model credentials :param credentials: The credentials required to access the TTS model.
:return: voices lists :return: A list of voices supported by the TTS model.
""" """
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_tts_model_voices( return plugin_model_manager.get_tts_model_voices(

@ -358,8 +358,8 @@ class TraceTask:
workflow_run_id = workflow_run.id workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status workflow_run_status = workflow_run.status
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version workflow_run_version = workflow_run.version
error = workflow_run.error or "" error = workflow_run.error or ""

@ -21,7 +21,7 @@ from core.plugin.entities.request import (
) )
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant from models.account import Tenant

@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.parameter_extractor.entities import ( from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig, ModelConfig as ParameterExtractorModelConfig,
) )

@ -1,12 +1,15 @@
from typing import Optional, Union from collections.abc import Sequence
from typing import Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageRole, PromptMessageRole,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
) )
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
class AdvancedPromptTransform(PromptTransform): class AdvancedPromptTransform(PromptTransform):
@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform):
def get_prompt( def get_prompt(
self, self,
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], *,
inputs: dict, prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
inputs: dict[str, str],
query: str, query: str,
files: list[FileVar], files: Sequence[File],
context: Optional[str], context: Optional[str],
memory_config: Optional[MemoryConfig], memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}
prompt_messages = [] prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode) if isinstance(prompt_template, CompletionModelPromptTemplate):
if model_mode == ModelMode.COMPLETION:
prompt_messages = self._get_completion_model_prompt_messages( prompt_messages = self._get_completion_model_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
inputs=inputs, inputs=inputs,
@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform):
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
) )
elif model_mode == ModelMode.CHAT: elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
prompt_messages = self._get_chat_model_prompt_messages( prompt_messages = self._get_chat_model_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
inputs=inputs, inputs=inputs,
query=query, query=query,
query_prompt_template=query_prompt_template,
files=files, files=files,
context=context, context=context,
memory_config=memory_config, memory_config=memory_config,
@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_template: CompletionModelPromptTemplate, prompt_template: CompletionModelPromptTemplate,
inputs: dict, inputs: dict,
query: Optional[str], query: Optional[str],
files: list[FileVar], files: Sequence[File],
context: Optional[str], context: Optional[str],
memory_config: Optional[MemoryConfig], memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = [] prompt_messages = []
if prompt_template.edition_type == "basic" or not prompt_template.edition_type: if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
if memory and memory_config: if memory and memory_config:
role_prefix = memory_config.role_prefix role_prefix = memory_config.role_prefix
@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config, memory_config=memory_config,
raw_prompt=raw_prompt, raw_prompt=raw_prompt,
role_prefix=role_prefix, role_prefix=role_prefix,
prompt_template=prompt_template, parser=parser,
prompt_inputs=prompt_inputs, prompt_inputs=prompt_inputs,
model_config=model_config, model_config=model_config,
) )
if query: if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
prompt = prompt_template.format(prompt_inputs) prompt = parser.format(prompt_inputs)
else: else:
prompt = raw_prompt prompt = raw_prompt
prompt_inputs = inputs prompt_inputs = inputs
@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform):
prompt = Jinja2Formatter.format(prompt, prompt_inputs) prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files: if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform):
prompt_template: list[ChatModelMessage], prompt_template: list[ChatModelMessage],
inputs: dict, inputs: dict,
query: Optional[str], query: Optional[str],
files: list[FileVar], files: Sequence[File],
context: Optional[str], context: Optional[str],
memory_config: Optional[MemoryConfig], memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
""" """
Get chat model prompt messages. Get chat model prompt messages.
""" """
raw_prompt_list = prompt_template
prompt_messages = [] prompt_messages = []
for prompt_item in prompt_template:
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text raw_prompt = prompt_item.text
if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) if self.with_variable_tmpl:
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} vp = VariablePool()
for k, v in inputs.items():
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
prompt = prompt_template.format(prompt_inputs) raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(
context=context, parser=parser, prompt_inputs=prompt_inputs
)
prompt = parser.format(prompt_inputs)
elif prompt_item.edition_type == "jinja2": elif prompt_item.edition_type == "jinja2":
prompt = raw_prompt prompt = raw_prompt
prompt_inputs = inputs prompt_inputs = inputs
prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
else: else:
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform):
elif prompt_item.role == PromptMessageRole.ASSISTANT: elif prompt_item.role == PromptMessageRole.ASSISTANT:
prompt_messages.append(AssistantPromptMessage(content=prompt)) prompt_messages.append(AssistantPromptMessage(content=prompt))
if query and query_prompt_template: if query and memory_config and memory_config.query_prompt_template:
prompt_template = PromptTemplateParser( parser = PromptTemplateParser(
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
) )
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs["#sys.query#"] = query prompt_inputs["#sys.query#"] = query
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
query = prompt_template.format(prompt_inputs) query = parser.format(prompt_inputs)
if memory and memory_config: if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files: if files and query is not None:
prompt_message_contents = [TextPromptMessageContent(data=query)] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
prompt_messages.append(UserPromptMessage(content=query)) prompt_messages.append(UserPromptMessage(content=query))
@ -200,19 +203,19 @@ class AdvancedPromptTransform(PromptTransform):
# get last user message content and add files # get last user message content and add files
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
last_message.content = prompt_message_contents last_message.content = prompt_message_contents
else: else:
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
prompt_message_contents = [TextPromptMessageContent(data=query)] prompt_message_contents = [TextPromptMessageContent(data=query)]
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
elif query: elif query:
@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages return prompt_messages
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#context#" in prompt_template.variable_keys: if "#context#" in parser.variable_keys:
if context: if context:
prompt_inputs["#context#"] = context prompt_inputs["#context#"] = context
else: else:
@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_inputs return prompt_inputs
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#query#" in prompt_template.variable_keys: if "#query#" in parser.variable_keys:
if query: if query:
prompt_inputs["#query#"] = query prompt_inputs["#query#"] = query
else: else:
@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: MemoryConfig, memory_config: MemoryConfig,
raw_prompt: str, raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix, role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser, parser: PromptTemplateParser,
prompt_inputs: dict, prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
) -> dict: ) -> dict:
if "#histories#" in prompt_template.variable_keys: if "#histories#" in parser.variable_keys:
if memory: if memory:
inputs = {"#histories#": "", **prompt_inputs} inputs = {"#histories#": "", **prompt_inputs}
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
from core.app.app_config.entities import PromptTemplateEntity from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
PromptMessage, PromptMessage,
PromptMessageContent,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode from models.model import AppMode
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.file_obj import FileVar from core.file.models import File
class ModelMode(enum.Enum): class ModelMode(str, enum.Enum):
COMPLETION = "completion" COMPLETION = "completion"
CHAT = "chat" CHAT = "chat"
@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict, inputs: dict,
query: str, query: str,
files: list["FileVar"], files: list["File"],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
files: list["FileVar"], files: list["File"],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]: ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
files: list["FileVar"], files: list["File"],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]: ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
if files: if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message = UserPromptMessage(content=prompt_message_contents) prompt_message = UserPromptMessage(content=prompt_message_contents)
else: else:

@ -1,7 +1,9 @@
from typing import Any
from constants import UUID_NIL from constants import UUID_NIL
def extract_thread_messages(messages: list[dict]) -> list[dict]: def extract_thread_messages(messages: list[Any]):
thread_messages = [] thread_messages = []
next_message = None next_message = None

@ -1,7 +1,8 @@
from typing import cast from typing import cast
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
@ -21,7 +22,7 @@ class PromptMessageUtil:
:return: :return:
""" """
prompts = [] prompts = []
if model_mode == ModelMode.CHAT.value: if model_mode == ModelMode.CHAT:
tool_calls = [] tool_calls = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER: if prompt_message.role == PromptMessageRole.USER:
@ -51,11 +52,9 @@ class PromptMessageUtil:
files = [] files = []
if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
for content in prompt_message.content: for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT: if isinstance(content, TextPromptMessageContent):
content = cast(TextPromptMessageContent, content)
text += content.data text += content.data
else: elif isinstance(content, ImagePromptMessageContent):
content = cast(ImagePromptMessageContent, content)
files.append( files.append(
{ {
"type": "image", "type": "image",
@ -63,6 +62,14 @@ class PromptMessageUtil:
"detail": content.detail.value, "detail": content.detail.value,
} }
) )
elif isinstance(content, AudioPromptMessageContent):
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
else: else:
text = prompt_message.content text = prompt_message.content

@ -14,6 +14,7 @@ from core.entities.provider_entities import (
CustomProviderConfiguration, CustomProviderConfiguration,
ModelLoadBalancingConfiguration, ModelLoadBalancingConfiguration,
ModelSettings, ModelSettings,
ProviderQuotaType,
QuotaConfiguration, QuotaConfiguration,
SystemConfiguration, SystemConfiguration,
) )
@ -31,7 +32,6 @@ from models.provider import (
Provider, Provider,
ProviderModel, ProviderModel,
ProviderModelSetting, ProviderModelSetting,
ProviderQuotaType,
ProviderType, ProviderType,
TenantDefaultModel, TenantDefaultModel,
TenantPreferredModelProvider, TenantPreferredModelProvider,

@ -1,14 +1,14 @@
from typing import Optional from typing import Optional
from core.model_manager import ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.weight_rerank import WeightRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory
from core.rag.rerank.rerank_type import RerankMode
class DataPostProcessor: class DataPostProcessor:
@ -47,11 +47,12 @@ class DataPostProcessor:
tenant_id: str, tenant_id: str,
reranking_model: Optional[dict] = None, reranking_model: Optional[dict] = None,
weights: Optional[dict] = None, weights: Optional[dict] = None,
) -> Optional[RerankModelRunner | WeightRerankRunner]: ) -> Optional[BaseRerankRunner]:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
return WeightRerankRunner( runner = RerankRunnerFactory.create_rerank_runner(
tenant_id, runner_type=reranking_mode,
Weights( tenant_id=tenant_id,
weights=Weights(
vector_setting=VectorSetting( vector_setting=VectorSetting(
vector_weight=weights["vector_setting"]["vector_weight"], vector_weight=weights["vector_setting"]["vector_weight"],
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
@ -62,7 +63,23 @@ class DataPostProcessor:
), ),
), ),
) )
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value: elif reranking_mode == RerankMode.RERANKING_MODEL.value:
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
if rerank_model_instance is None:
return None
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode, rerank_model_instance=rerank_model_instance
)
return runner
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None:
if reranking_model: if reranking_model:
try: try:
model_manager = ModelManager() model_manager = ModelManager()
@ -72,13 +89,7 @@ class DataPostProcessor:
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
model=reranking_model["reranking_model_name"], model=reranking_model["reranking_model_name"],
) )
return rerank_model_instance
except InvokeAuthorizationError: except InvokeAuthorizationError:
return None return None
return RerankModelRunner(rerank_model_instance)
return None
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None return None

@ -6,7 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset

@ -9,10 +9,10 @@ _import_err_msg = (
) )
from configs import dify_config from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save