Merge main into fix/chore-fix
commit
bedbd658fe
@ -0,0 +1,213 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||||
|
from gevent import monkey
|
||||||
|
|
||||||
|
monkey.patch_all()
|
||||||
|
|
||||||
|
import grpc.experimental.gevent
|
||||||
|
|
||||||
|
grpc.experimental.gevent.init_gevent()
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
from flask_cors import CORS
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from commands import register_commands
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions import (
|
||||||
|
ext_celery,
|
||||||
|
ext_code_based_extension,
|
||||||
|
ext_compress,
|
||||||
|
ext_database,
|
||||||
|
ext_hosting_provider,
|
||||||
|
ext_login,
|
||||||
|
ext_mail,
|
||||||
|
ext_migrate,
|
||||||
|
ext_proxy_fix,
|
||||||
|
ext_redis,
|
||||||
|
ext_sentry,
|
||||||
|
ext_storage,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_login import login_manager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
|
class DifyApp(Flask):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Application Factory Function
|
||||||
|
# ----------------------------
|
||||||
|
def create_flask_app_with_configs() -> Flask:
|
||||||
|
"""
|
||||||
|
create a raw flask app
|
||||||
|
with configs loaded from .env file
|
||||||
|
"""
|
||||||
|
dify_app = DifyApp(__name__)
|
||||||
|
dify_app.config.from_mapping(dify_config.model_dump())
|
||||||
|
|
||||||
|
# populate configs into system environment variables
|
||||||
|
for key, value in dify_app.config.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
os.environ[key] = value
|
||||||
|
elif isinstance(value, int | float | bool):
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
elif value is None:
|
||||||
|
os.environ[key] = ""
|
||||||
|
|
||||||
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> Flask:
|
||||||
|
app = create_flask_app_with_configs()
|
||||||
|
|
||||||
|
app.secret_key = app.config["SECRET_KEY"]
|
||||||
|
|
||||||
|
log_handlers = None
|
||||||
|
log_file = app.config.get("LOG_FILE")
|
||||||
|
if log_file:
|
||||||
|
log_dir = os.path.dirname(log_file)
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_handlers = [
|
||||||
|
RotatingFileHandler(
|
||||||
|
filename=log_file,
|
||||||
|
maxBytes=1024 * 1024 * 1024,
|
||||||
|
backupCount=5,
|
||||||
|
),
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
]
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=app.config.get("LOG_LEVEL"),
|
||||||
|
format=app.config.get("LOG_FORMAT"),
|
||||||
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
|
handlers=log_handlers,
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
log_tz = app.config.get("LOG_TZ")
|
||||||
|
if log_tz:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
|
def time_converter(seconds):
|
||||||
|
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||||
|
|
||||||
|
for handler in logging.root.handlers:
|
||||||
|
handler.formatter.converter = time_converter
|
||||||
|
initialize_extensions(app)
|
||||||
|
register_blueprints(app)
|
||||||
|
register_commands(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_extensions(app):
|
||||||
|
# Since the application instance is now created, pass it to each Flask
|
||||||
|
# extension instance to bind it to the Flask application instance (app)
|
||||||
|
ext_compress.init_app(app)
|
||||||
|
ext_code_based_extension.init()
|
||||||
|
ext_database.init_app(app)
|
||||||
|
ext_migrate.init(app, db)
|
||||||
|
ext_redis.init_app(app)
|
||||||
|
ext_storage.init_app(app)
|
||||||
|
ext_celery.init_app(app)
|
||||||
|
ext_login.init_app(app)
|
||||||
|
ext_mail.init_app(app)
|
||||||
|
ext_hosting_provider.init_app(app)
|
||||||
|
ext_sentry.init_app(app)
|
||||||
|
ext_proxy_fix.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
|
# Flask-Login configuration
|
||||||
|
@login_manager.request_loader
|
||||||
|
def load_user_from_request(request_from_flask_login):
|
||||||
|
"""Load user based on the request."""
|
||||||
|
if request.blueprint not in {"console", "inner_api"}:
|
||||||
|
return None
|
||||||
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header:
|
||||||
|
auth_token = request.args.get("_token")
|
||||||
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
else:
|
||||||
|
if " " not in auth_header:
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
|
auth_scheme = auth_scheme.lower()
|
||||||
|
if auth_scheme != "bearer":
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
|
||||||
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
|
if logged_in_account:
|
||||||
|
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||||
|
return logged_in_account
|
||||||
|
|
||||||
|
|
||||||
|
@login_manager.unauthorized_handler
|
||||||
|
def unauthorized_handler():
|
||||||
|
"""Handle unauthorized requests."""
|
||||||
|
return Response(
|
||||||
|
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||||
|
status=401,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# register blueprint routers
|
||||||
|
def register_blueprints(app):
|
||||||
|
from controllers.console import bp as console_app_bp
|
||||||
|
from controllers.files import bp as files_bp
|
||||||
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
|
from controllers.service_api import bp as service_api_bp
|
||||||
|
from controllers.web import bp as web_bp
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
service_api_bp,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
)
|
||||||
|
app.register_blueprint(service_api_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
web_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(web_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
console_app_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(console_app_bp)
|
||||||
|
|
||||||
|
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||||
|
app.register_blueprint(files_bp)
|
||||||
|
|
||||||
|
app.register_blueprint(inner_api_bp)
|
||||||
@ -1,2 +1,22 @@
|
|||||||
|
from configs import dify_config
|
||||||
|
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||||
|
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||||
|
|
||||||
|
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||||
|
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||||
|
|
||||||
|
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||||
|
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||||
|
|
||||||
|
|
||||||
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
else:
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||||
|
|
||||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
|
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||||
|
|||||||
@ -1,18 +0,0 @@
|
|||||||
import re
|
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
|
|
||||||
from . import SegmentGroup, factory
|
|
||||||
|
|
||||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_template(*, template: str, variable_pool: VariablePool):
|
|
||||||
parts = re.split(VARIABLE_PATTERN, template)
|
|
||||||
segments = []
|
|
||||||
for part in filter(lambda x: x, parts):
|
|
||||||
if "." in part and (value := variable_pool.get(part.split("."))):
|
|
||||||
segments.append(value)
|
|
||||||
else:
|
|
||||||
segments.append(factory.build_segment(part))
|
|
||||||
return SegmentGroup(value=segments)
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
import enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in PromptMessageFileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFile(BaseModel):
|
|
||||||
type: PromptMessageFileType
|
|
||||||
data: Any = None
|
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptMessageFile(PromptMessageFile):
|
|
||||||
class DETAIL(enum.Enum):
|
|
||||||
LOW = "low"
|
|
||||||
HIGH = "high"
|
|
||||||
|
|
||||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
|
||||||
detail: DETAIL = DETAIL.LOW
|
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||||
|
from .models import (
|
||||||
|
File,
|
||||||
|
FileExtraConfig,
|
||||||
|
ImageConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FileType",
|
||||||
|
"FileExtraConfig",
|
||||||
|
"FileTransferMethod",
|
||||||
|
"FileBelongsTo",
|
||||||
|
"File",
|
||||||
|
"ImageConfig",
|
||||||
|
"FileAttribute",
|
||||||
|
"ArrayFileAttribute",
|
||||||
|
"FILE_MODEL_IDENTITY",
|
||||||
|
]
|
||||||
@ -0,0 +1 @@
|
|||||||
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class FileType(str, Enum):
|
||||||
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileType:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileTransferMethod(str, Enum):
|
||||||
|
REMOTE_URL = "remote_url"
|
||||||
|
LOCAL_FILE = "local_file"
|
||||||
|
TOOL_FILE = "tool_file"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileTransferMethod:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileBelongsTo(str, Enum):
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileBelongsTo:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileAttribute(str, Enum):
|
||||||
|
TYPE = "type"
|
||||||
|
SIZE = "size"
|
||||||
|
NAME = "name"
|
||||||
|
MIME_TYPE = "mime_type"
|
||||||
|
TRANSFER_METHOD = "transfer_method"
|
||||||
|
URL = "url"
|
||||||
|
EXTENSION = "extension"
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileAttribute(str, Enum):
|
||||||
|
LENGTH = "length"
|
||||||
@ -0,0 +1,156 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.file import file_repository
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .enums import FileAttribute
|
||||||
|
from .models import File, FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr(*, file: File, attr: FileAttribute):
|
||||||
|
match attr:
|
||||||
|
case FileAttribute.TYPE:
|
||||||
|
return file.type.value
|
||||||
|
case FileAttribute.SIZE:
|
||||||
|
return file.size
|
||||||
|
case FileAttribute.NAME:
|
||||||
|
return file.filename
|
||||||
|
case FileAttribute.MIME_TYPE:
|
||||||
|
return file.mime_type
|
||||||
|
case FileAttribute.TRANSFER_METHOD:
|
||||||
|
return file.transfer_method.value
|
||||||
|
case FileAttribute.URL:
|
||||||
|
return file.remote_url
|
||||||
|
case FileAttribute.EXTENSION:
|
||||||
|
return file.extension
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid file attribute: {attr}")
|
||||||
|
|
||||||
|
|
||||||
|
def to_prompt_message_content(f: File, /):
|
||||||
|
"""
|
||||||
|
Convert a File object to an ImagePromptMessageContent object.
|
||||||
|
|
||||||
|
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||||
|
object, which can be used as a prompt for image-based AI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the file is not an image or if the file data is missing.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The detail level of the image prompt is determined by the file's extra_config.
|
||||||
|
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||||
|
"""
|
||||||
|
match f.type:
|
||||||
|
case FileType.IMAGE:
|
||||||
|
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||||
|
data = _to_url(f)
|
||||||
|
else:
|
||||||
|
data = _to_base64_data_string(f)
|
||||||
|
|
||||||
|
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
|
||||||
|
detail = f._extra_config.image_config.detail
|
||||||
|
else:
|
||||||
|
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
|
||||||
|
return ImagePromptMessageContent(data=data, detail=detail)
|
||||||
|
case FileType.AUDIO:
|
||||||
|
encoded_string = _file_to_encoded_string(f)
|
||||||
|
if f.extension is None:
|
||||||
|
raise ValueError("Missing file extension")
|
||||||
|
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"file type {f.type} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def download(f: File, /):
|
||||||
|
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||||
|
return _download_file_content(upload_file.key)
|
||||||
|
|
||||||
|
|
||||||
|
def _download_file_content(path: str, /):
|
||||||
|
"""
|
||||||
|
Download and return the contents of a file as bytes.
|
||||||
|
|
||||||
|
This function loads the file from storage and ensures it's in bytes format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the file in storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The contents of the file as a bytes object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the loaded file is not a bytes object.
|
||||||
|
"""
|
||||||
|
data = storage.load(path, stream=False)
|
||||||
|
if not isinstance(data, bytes):
|
||||||
|
raise ValueError(f"file {path} is not a bytes object")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_encoded_string(f: File, /):
|
||||||
|
match f.transfer_method:
|
||||||
|
case FileTransferMethod.REMOTE_URL:
|
||||||
|
response = ssrf_proxy.get(f.remote_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
content = response.content
|
||||||
|
encoded_string = base64.b64encode(content).decode("utf-8")
|
||||||
|
return encoded_string
|
||||||
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
|
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||||
|
data = _download_file_content(upload_file.key)
|
||||||
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
|
return encoded_string
|
||||||
|
case FileTransferMethod.TOOL_FILE:
|
||||||
|
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||||
|
data = _download_file_content(tool_file.file_key)
|
||||||
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
|
return encoded_string
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_base64_data_string(f: File, /):
|
||||||
|
encoded_string = _get_encoded_string(f)
|
||||||
|
return f"data:{f.mime_type};base64,{encoded_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def _file_to_encoded_string(f: File, /):
|
||||||
|
match f.type:
|
||||||
|
case FileType.IMAGE:
|
||||||
|
return _to_base64_data_string(f)
|
||||||
|
case FileType.AUDIO:
|
||||||
|
return _get_encoded_string(f)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"file type {f.type} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_url(f: File, /):
|
||||||
|
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
if f.remote_url is None:
|
||||||
|
raise ValueError("Missing file remote_url")
|
||||||
|
return f.remote_url
|
||||||
|
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if f.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||||
|
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
# add sign url
|
||||||
|
if f.related_id is None or f.extension is None:
|
||||||
|
raise ValueError("Missing file related_id or extension")
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||||
@ -1,145 +0,0 @@
|
|||||||
import enum
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
|
||||||
from core.file.upload_file_parser import UploadFileParser
|
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
|
|
||||||
class FileExtraConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
File Upload Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
image_config: Optional[dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileTransferMethod(enum.Enum):
|
|
||||||
REMOTE_URL = "remote_url"
|
|
||||||
LOCAL_FILE = "local_file"
|
|
||||||
TOOL_FILE = "tool_file"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileTransferMethod:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileBelongsTo(enum.Enum):
|
|
||||||
USER = "user"
|
|
||||||
ASSISTANT = "assistant"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileBelongsTo:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileVar(BaseModel):
|
|
||||||
id: Optional[str] = None # message file id
|
|
||||||
tenant_id: str
|
|
||||||
type: FileType
|
|
||||||
transfer_method: FileTransferMethod
|
|
||||||
url: Optional[str] = None # remote url
|
|
||||||
related_id: Optional[str] = None
|
|
||||||
extra_config: Optional[FileExtraConfig] = None
|
|
||||||
filename: Optional[str] = None
|
|
||||||
extension: Optional[str] = None
|
|
||||||
mime_type: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"__variant": self.__class__.__name__,
|
|
||||||
"tenant_id": self.tenant_id,
|
|
||||||
"type": self.type.value,
|
|
||||||
"transfer_method": self.transfer_method.value,
|
|
||||||
"url": self.preview_url,
|
|
||||||
"remote_url": self.url,
|
|
||||||
"related_id": self.related_id,
|
|
||||||
"filename": self.filename,
|
|
||||||
"extension": self.extension,
|
|
||||||
"mime_type": self.mime_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_markdown(self) -> str:
|
|
||||||
"""
|
|
||||||
Convert file to markdown
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
preview_url = self.preview_url
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
text = f''
|
|
||||||
else:
|
|
||||||
text = f"[{self.filename or preview_url}]({preview_url})"
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get image data, file signed url or base64 data
|
|
||||||
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def preview_url(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get signed preview url
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data(force_url=True)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
image_config = self.extra_config.image_config
|
|
||||||
|
|
||||||
return ImagePromptMessageContent(
|
|
||||||
data=self.data,
|
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
|
||||||
if image_config.get("detail") == "high"
|
|
||||||
else ImagePromptMessageContent.DETAIL.LOW,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
|
||||||
from models.model import UploadFile
|
|
||||||
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
return self.url
|
|
||||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
|
|
||||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
|
||||||
extension = self.extension
|
|
||||||
# add sign url
|
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
|
||||||
tool_file_id=self.related_id, extension=extension
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models import ToolFile, UploadFile
|
||||||
|
|
||||||
|
from .models import File
|
||||||
|
|
||||||
|
|
||||||
|
def get_upload_file(*, session: Session, file: File):
|
||||||
|
if file.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
stmt = select(UploadFile).filter(
|
||||||
|
UploadFile.id == file.related_id,
|
||||||
|
UploadFile.tenant_id == file.tenant_id,
|
||||||
|
)
|
||||||
|
record = session.scalar(stmt)
|
||||||
|
if not record:
|
||||||
|
raise ValueError(f"upload file {file.related_id} not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_file(*, session: Session, file: File):
|
||||||
|
if file.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
stmt = select(ToolFile).filter(
|
||||||
|
ToolFile.id == file.related_id,
|
||||||
|
ToolFile.tenant_id == file.tenant_id,
|
||||||
|
)
|
||||||
|
record = session.scalar(stmt)
|
||||||
|
if not record:
|
||||||
|
raise ValueError(f"tool file {file.related_id} not found")
|
||||||
|
return record
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_signed_file_url(upload_file_id: str) -> str:
|
||||||
|
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
key = dify_config.SECRET_KEY.encode()
|
||||||
|
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
@ -1,243 +0,0 @@
|
|||||||
import re
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from typing import Any, Union
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.account import Account
|
|
||||||
from models.model import EndUser, MessageFile, UploadFile
|
|
||||||
from services.file_service import IMAGE_EXTENSIONS
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFileParser:
|
|
||||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.app_id = app_id
|
|
||||||
|
|
||||||
def validate_and_transform_files_arg(
|
|
||||||
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
|
|
||||||
) -> list[FileVar]:
|
|
||||||
"""
|
|
||||||
validate and transform files arg
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:param user:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
if not isinstance(file, dict):
|
|
||||||
raise ValueError("Invalid file format, must be dict")
|
|
||||||
if not file.get("type"):
|
|
||||||
raise ValueError("Missing file type")
|
|
||||||
FileType.value_of(file.get("type"))
|
|
||||||
if not file.get("transfer_method"):
|
|
||||||
raise ValueError("Missing file transfer method")
|
|
||||||
FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
|
|
||||||
if not file.get("url"):
|
|
||||||
raise ValueError("Missing file url")
|
|
||||||
if not file.get("url").startswith("http"):
|
|
||||||
raise ValueError("Invalid file url")
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
|
|
||||||
raise ValueError("Missing file upload_file_id")
|
|
||||||
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
|
|
||||||
raise ValueError("Missing file tool_file_id")
|
|
||||||
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# validate files
|
|
||||||
new_files = []
|
|
||||||
for file_type, file_objs in type_file_objs.items():
|
|
||||||
if file_type == FileType.IMAGE:
|
|
||||||
# parse and validate files
|
|
||||||
image_config = file_extra_config.image_config
|
|
||||||
|
|
||||||
# check if image file feature is enabled
|
|
||||||
if not image_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Validate number of files
|
|
||||||
if len(files) > image_config["number_limits"]:
|
|
||||||
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
|
|
||||||
|
|
||||||
for file_obj in file_objs:
|
|
||||||
# Validate transfer method
|
|
||||||
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
|
|
||||||
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
|
|
||||||
|
|
||||||
# Validate file type
|
|
||||||
if file_obj.type != FileType.IMAGE:
|
|
||||||
raise ValueError(f"Invalid file type: {file_obj.type}")
|
|
||||||
|
|
||||||
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
# check remote url valid and is image
|
|
||||||
result, error = self._check_image_remote_url(file_obj.url)
|
|
||||||
if result is False:
|
|
||||||
raise ValueError(error)
|
|
||||||
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
# get upload file from upload_file_id
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(
|
|
||||||
UploadFile.id == file_obj.related_id,
|
|
||||||
UploadFile.tenant_id == self.tenant_id,
|
|
||||||
UploadFile.created_by == user.id,
|
|
||||||
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
|
||||||
UploadFile.extension.in_(IMAGE_EXTENSIONS),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# check upload file is belong to tenant and user
|
|
||||||
if not upload_file:
|
|
||||||
raise ValueError("Invalid upload file")
|
|
||||||
|
|
||||||
new_files.append(file_obj)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return new_files
|
|
||||||
|
|
||||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform message files
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
|
||||||
|
|
||||||
def _to_file_objs(
|
|
||||||
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
|
|
||||||
) -> dict[FileType, list[FileVar]]:
|
|
||||||
"""
|
|
||||||
transform files to file objs
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
type_file_objs: dict[FileType, list[FileVar]] = {
|
|
||||||
# Currently only support image
|
|
||||||
FileType.IMAGE: []
|
|
||||||
}
|
|
||||||
|
|
||||||
if not files:
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
# group by file type and convert file args or message files to FileObj
|
|
||||||
for file in files:
|
|
||||||
if isinstance(file, MessageFile):
|
|
||||||
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_obj = self._to_file_obj(file, file_extra_config)
|
|
||||||
if file_obj.type not in type_file_objs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
type_file_objs[file_obj.type].append(file_obj)
|
|
||||||
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform file to file obj
|
|
||||||
|
|
||||||
:param file:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if isinstance(file, dict):
|
|
||||||
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if transfer_method != FileTransferMethod.TOOL_FILE:
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
|
||||||
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=None,
|
|
||||||
related_id=file.get("tool_file_id"),
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return FileVar(
|
|
||||||
id=file.id,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.type),
|
|
||||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
|
||||||
url=file.url,
|
|
||||||
related_id=file.upload_file_id or None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_image_remote_url(self, url):
|
|
||||||
try:
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
|
||||||
" Chrome/91.0.4472.124 Safari/537.36"
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_s3_presigned_url(url):
|
|
||||||
try:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
if "amazonaws.com" not in parsed_url.netloc:
|
|
||||||
return False
|
|
||||||
query_params = parse_qs(parsed_url.query)
|
|
||||||
|
|
||||||
def check_presign_v2(query_params):
|
|
||||||
required_params = ["Signature", "Expires"]
|
|
||||||
for param in required_params:
|
|
||||||
if param not in query_params:
|
|
||||||
return False
|
|
||||||
if not query_params["Expires"][0].isdigit():
|
|
||||||
return False
|
|
||||||
signature = query_params["Signature"][0]
|
|
||||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def check_presign_v4(query_params):
|
|
||||||
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
|
|
||||||
for param in required_params:
|
|
||||||
if param not in query_params:
|
|
||||||
return False
|
|
||||||
if not query_params["X-Amz-Expires"][0].isdigit():
|
|
||||||
return False
|
|
||||||
signature = query_params["X-Amz-Signature"][0]
|
|
||||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
return check_presign_v4(query_params) or check_presign_v2(query_params)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if is_s3_presigned_url(url):
|
|
||||||
response = requests.get(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
else:
|
|
||||||
return False, "URL does not exist."
|
|
||||||
except requests.RequestException as e:
|
|
||||||
return False, f"Error checking URL: {e}"
|
|
||||||
@ -0,0 +1,140 @@
|
|||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
"""
|
||||||
|
|
||||||
|
number_limits: int = 0
|
||||||
|
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileExtraConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
File Upload Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_config: Optional[ImageConfig] = None
|
||||||
|
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||||
|
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||||
|
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
number_limits: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||||
|
|
||||||
|
id: Optional[str] = None # message file id
|
||||||
|
tenant_id: str
|
||||||
|
type: FileType
|
||||||
|
transfer_method: FileTransferMethod
|
||||||
|
remote_url: Optional[str] = None # remote url
|
||||||
|
related_id: Optional[str] = None
|
||||||
|
filename: Optional[str] = None
|
||||||
|
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||||
|
mime_type: Optional[str] = None
|
||||||
|
size: int = -1
|
||||||
|
_extra_config: FileExtraConfig | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||||
|
data = self.model_dump(mode="json")
|
||||||
|
return {
|
||||||
|
**data,
|
||||||
|
"url": self.generate_url(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
url = self.generate_url()
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
text = f''
|
||||||
|
else:
|
||||||
|
text = f"[{self.filename or url}]({url})"
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def generate_url(self) -> Optional[str]:
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_after(self):
|
||||||
|
match self.transfer_method:
|
||||||
|
case FileTransferMethod.REMOTE_URL:
|
||||||
|
if not self.remote_url:
|
||||||
|
raise ValueError("Missing file url")
|
||||||
|
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
|
||||||
|
raise ValueError("Invalid file url")
|
||||||
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
case FileTransferMethod.TOOL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
|
# Validate the extra config.
|
||||||
|
if not self._extra_config:
|
||||||
|
return self
|
||||||
|
|
||||||
|
if self._extra_config.allowed_file_types:
|
||||||
|
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
||||||
|
raise ValueError(f"Invalid file type: {self.type}")
|
||||||
|
|
||||||
|
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
||||||
|
raise ValueError(f"Invalid file extension: {self.extension}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._extra_config.allowed_upload_methods
|
||||||
|
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
match self.type:
|
||||||
|
case FileType.IMAGE:
|
||||||
|
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
if not self._extra_config.image_config:
|
||||||
|
return self
|
||||||
|
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
||||||
|
if (
|
||||||
|
self._extra_config.image_config.transfer_methods
|
||||||
|
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
return self
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
from .message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
AudioPromptMessageContent,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageRole,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from .model_entities import ModelPropertyKey
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImagePromptMessageContent",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"LLMUsage",
|
||||||
|
"ModelPropertyKey",
|
||||||
|
"AssistantPromptMessage",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageContent",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"SystemPromptMessage",
|
||||||
|
"TextPromptMessageContent",
|
||||||
|
"UserPromptMessage",
|
||||||
|
"PromptMessageTool",
|
||||||
|
"ToolPromptMessage",
|
||||||
|
"PromptMessageContentType",
|
||||||
|
"LLMResult",
|
||||||
|
"LLMResultChunk",
|
||||||
|
"LLMResultChunkDelta",
|
||||||
|
"AudioPromptMessageContent",
|
||||||
|
]
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue