feat/enhance the multi-modal support (#8818)
parent
7a1d6fe509
commit
e61752bd3a
@ -1,2 +1,22 @@
|
||||
from configs import dify_config
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
# Nil UUID: every hex digit is zero.
UUID_NIL = "00000000-0000-0000-0000-000000000000"

# Each extension list holds the lowercase spellings first, followed by the
# same entries upper-cased, so membership tests accept either casing.
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS += [name.upper() for name in IMAGE_EXTENSIONS]

VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS += [name.upper() for name in VIDEO_EXTENSIONS]

AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS += [name.upper() for name in AUDIO_EXTENSIONS]
|
||||
|
||||
|
||||
# Document extensions accepted regardless of the configured ETL type.
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
if dify_config.ETL_TYPE == "Unstructured":
    # Additional formats listed only when ETL_TYPE is "Unstructured".
    DOCUMENT_EXTENSIONS += ["eml", "msg", "pptx", "ppt", "xml", "epub"]
# Append the upper-cased twin of every entry, after all lowercase entries.
DOCUMENT_EXTENSIONS += [name.upper() for name in DOCUMENT_EXTENSIONS]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Context-local tenant identifier.
tenant_id: ContextVar[str] = ContextVar("tenant_id")

# Context-local workflow variable pool.
# ``VariablePool`` is imported only under ``TYPE_CHECKING``, and module-level
# annotations are evaluated at import time, so the type argument must be a
# string; the unquoted form would raise NameError when this module is imported.
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
import re
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from . import SegmentGroup, factory
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")


def convert_template(*, template: str, variable_pool: VariablePool):
    """Split *template* on ``{{#a.b#}}`` references and build a SegmentGroup.

    The pattern has one capturing group, so splitting yields the literal text
    pieces interleaved with the captured selectors. A piece containing a dot is
    looked up in *variable_pool* (split on "."); when the lookup yields a truthy
    value that value is used, otherwise the piece is turned into a segment via
    ``factory.build_segment``. Empty pieces are dropped.
    """
    segments = []
    for piece in VARIABLE_PATTERN.split(template):
        if not piece:
            continue
        resolved = variable_pool.get(piece.split(".")) if "." in piece else None
        if resolved:
            segments.append(resolved)
        else:
            segments.append(factory.build_segment(piece))
    return SegmentGroup(value=segments)
|
||||
@ -1,29 +0,0 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptMessageFileType(enum.Enum):
    """Kinds of file a prompt message file can be."""

    IMAGE = "image"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in PromptMessageFileType if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class PromptMessageFile(BaseModel):
    """Base model for a file carried by a prompt message."""

    # Kind of file; required.
    type: PromptMessageFileType
    # Payload with no type constraint; None unless supplied.
    data: Any = None
|
||||
|
||||
|
||||
class ImagePromptMessageFile(PromptMessageFile):
    """Prompt message file whose type defaults to IMAGE and which carries a detail level."""

    class DETAIL(enum.Enum):
        """Detail levels selectable for an image."""

        LOW = "low"
        HIGH = "high"

    # Overrides the base field with an IMAGE default.
    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    # LOW unless the caller asks for HIGH.
    detail: DETAIL = DETAIL.LOW
|
||||
@ -0,0 +1,19 @@
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||
from .models import (
|
||||
File,
|
||||
FileExtraConfig,
|
||||
ImageConfig,
|
||||
)
|
||||
|
||||
# Public names re-exported by this package, in alphabetical order.
__all__ = [
    "ArrayFileAttribute",
    "FILE_MODEL_IDENTITY",
    "File",
    "FileAttribute",
    "FileBelongsTo",
    "FileExtraConfig",
    "FileTransferMethod",
    "FileType",
    "ImageConfig",
]
|
||||
@ -0,0 +1 @@
|
||||
# Marker string used as the default of File.dify_model_identity.
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
@ -0,0 +1,55 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FileType(str, Enum):
    """Category of a file; members are also plain strings."""

    IMAGE = "image"
    DOCUMENT = "document"
    AUDIO = "audio"
    VIDEO = "video"
    CUSTOM = "custom"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileType if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileTransferMethod(str, Enum):
    """How a file's content is referenced; members are also plain strings."""

    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileTransferMethod if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileBelongsTo(str, Enum):
    """Which side of a conversation a file belongs to; members are also plain strings."""

    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileBelongsTo if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileAttribute(str, Enum):
    """Names of the per-file attributes that can be selected; members are also plain strings."""

    TYPE = "type"
    SIZE = "size"
    NAME = "name"
    MIME_TYPE = "mime_type"
    TRANSFER_METHOD = "transfer_method"
    URL = "url"
    EXTENSION = "extension"
|
||||
|
||||
|
||||
class ArrayFileAttribute(str, Enum):
    """Names of the attributes selectable on an array of files; members are also plain strings."""

    LENGTH = "length"
|
||||
@ -0,0 +1,156 @@
|
||||
import base64
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_repository
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
    """Return the value of the attribute of *file* selected by *attr*.

    TYPE and TRANSFER_METHOD return the enum's ``value``; NAME maps to
    ``file.filename`` and URL maps to ``file.remote_url``; the remaining
    attributes map to the field of the same name.

    Raises:
        ValueError: If *attr* matches none of the FileAttribute members.
    """
    if attr == FileAttribute.TYPE:
        return file.type.value
    if attr == FileAttribute.SIZE:
        return file.size
    if attr == FileAttribute.NAME:
        return file.filename
    if attr == FileAttribute.MIME_TYPE:
        return file.mime_type
    if attr == FileAttribute.TRANSFER_METHOD:
        return file.transfer_method.value
    if attr == FileAttribute.URL:
        return file.remote_url
    if attr == FileAttribute.EXTENSION:
        return file.extension
    raise ValueError(f"Invalid file attribute: {attr}")
|
||||
|
||||
|
||||
def to_prompt_message_content(f: File, /):
|
||||
"""
|
||||
Convert a File object to an ImagePromptMessageContent object.
|
||||
|
||||
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||
object, which can be used as a prompt for image-based AI models.
|
||||
|
||||
Args:
|
||||
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||
|
||||
Returns:
|
||||
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file is not an image or if the file data is missing.
|
||||
|
||||
Note:
|
||||
The detail level of the image prompt is determined by the file's extra_config.
|
||||
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||
"""
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
|
||||
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
|
||||
detail = f._extra_config.image_config.detail
|
||||
else:
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=detail)
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _file_to_encoded_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def download(f: File, /):
    """Look up the upload-file record for *f* and return its stored content as bytes."""
    storage_key = file_repository.get_upload_file(session=db.session(), file=f).key
    return _download_file_content(storage_key)
|
||||
|
||||
|
||||
def _download_file_content(path: str, /):
    """
    Load a file from storage (non-streaming) and return it as bytes.

    Args:
        path (str): The path to the file in storage.

    Returns:
        bytes: The contents of the file.

    Raises:
        ValueError: If storage returned something other than a bytes object.
    """
    content = storage.load(path, stream=False)
    if isinstance(content, bytes):
        return content
    raise ValueError(f"file {path} is not a bytes object")
|
||||
|
||||
|
||||
def _get_encoded_string(f: File, /):
|
||||
match f.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
encoded_string = base64.b64encode(content).decode("utf-8")
|
||||
return encoded_string
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
data = _download_file_content(upload_file.key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
data = _download_file_content(tool_file.file_key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def _to_base64_data_string(f: File, /):
    """Return the content of *f* as a ``data:<mime_type>;base64,<payload>`` string."""
    payload = _get_encoded_string(f)
    return "".join(("data:", f"{f.mime_type}", ";base64,", payload))
|
||||
|
||||
|
||||
def _file_to_encoded_string(f: File, /):
    """Encode *f* for transport: a base64 data string for images, bare base64 for audio.

    Raises:
        ValueError: If the file is neither an image nor audio.
    """
    if f.type == FileType.IMAGE:
        return _to_base64_data_string(f)
    if f.type == FileType.AUDIO:
        return _get_encoded_string(f)
    raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def _to_url(f: File, /):
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
return f.remote_url
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if f.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
# add sign url
|
||||
if f.related_id is None or f.extension is None:
|
||||
raise ValueError("Missing file related_id or extension")
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
@ -1,145 +0,0 @@
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
    """Extra configuration attached to a file upload."""

    # Free-form image settings mapping; None when not provided.
    image_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class FileType(enum.Enum):
    """Category of a file; only images are defined here."""

    IMAGE = "image"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileType if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileTransferMethod(enum.Enum):
    """How a file's content is referenced."""

    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileTransferMethod if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileBelongsTo(enum.Enum):
    """Which side of a conversation a file belongs to."""

    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((item for item in FileBelongsTo if item.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileVar(BaseModel):
    """A file referenced by a message, with helpers to render it or feed it to a model."""

    id: Optional[str] = None  # message file id
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str] = None  # remote url
    related_id: Optional[str] = None
    extra_config: Optional[FileExtraConfig] = None
    filename: Optional[str] = None
    extension: Optional[str] = None
    mime_type: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a plain-dict view of the file.

        "url" holds the preview URL, while the stored ``url`` field is exposed
        under "remote_url".
        """
        return {
            "__variant": self.__class__.__name__,
            "tenant_id": self.tenant_id,
            "type": self.type.value,
            "transfer_method": self.transfer_method.value,
            "url": self.preview_url,
            "remote_url": self.url,
            "related_id": self.related_id,
            "filename": self.filename,
            "extension": self.extension,
            "mime_type": self.mime_type,
        }

    def to_markdown(self) -> str:
        """
        Convert file to markdown

        Images are rendered with Markdown image syntax (``![alt](url)``, using
        the filename as alt text); anything else becomes a link labelled with
        the filename, or with the URL when there is no filename.

        :return: the markdown text
        """
        preview_url = self.preview_url
        if self.type == FileType.IMAGE:
            return f"![{self.filename or ''}]({preview_url})"
        return f"[{self.filename or preview_url}]({preview_url})"

    @property
    def data(self) -> Optional[str]:
        """
        Get image data, file signed url or base64 data
        depending on config MULTIMODAL_SEND_IMAGE_FORMAT
        :return:
        """
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        """
        Get signed preview url
        :return:
        """
        return self._get_data(force_url=True)

    @property
    def prompt_message_content(self) -> Optional[ImagePromptMessageContent]:
        """Build the image prompt content for this file.

        Returns None for non-image files. The detail level is HIGH only when
        the image config's "detail" entry is "high"; a missing ``extra_config``
        or ``image_config`` is treated as an empty config (detail LOW).
        """
        if self.type != FileType.IMAGE:
            return None

        image_config = (self.extra_config.image_config if self.extra_config else None) or {}
        if image_config.get("detail") == "high":
            detail = ImagePromptMessageContent.DETAIL.HIGH
        else:
            detail = ImagePromptMessageContent.DETAIL.LOW

        return ImagePromptMessageContent(data=self.data, detail=detail)

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        """Resolve the file to a string according to its transfer method.

        Returns the remote URL, the result of ``UploadFileParser.get_image_data``
        for local files, or a signed tool-file URL. Returns None for non-image
        files and for unhandled transfer methods.
        """
        # Imported inside the method, as in the original; presumably to avoid an import cycle.
        from models.model import UploadFile

        if self.type == FileType.IMAGE:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                upload_file = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
                    .first()
                )

                return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                extension = self.extension
                # add sign url
                return ToolFileParser.get_tool_file_manager().sign_file(
                    tool_file_id=self.related_id, extension=extension
                )

        return None
|
||||
@ -0,0 +1,32 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
from .models import File
|
||||
|
||||
|
||||
def get_upload_file(*, session: Session, file: File):
    """Fetch the UploadFile row matching the file's ``related_id`` and ``tenant_id``.

    Raises:
        ValueError: If ``file.related_id`` is None or no matching row exists.
    """
    upload_file_id = file.related_id
    if upload_file_id is None:
        raise ValueError("Missing file related_id")

    query = select(UploadFile).filter(
        UploadFile.id == upload_file_id,
        UploadFile.tenant_id == file.tenant_id,
    )
    found = session.scalar(query)
    if found:
        return found
    raise ValueError(f"upload file {upload_file_id} not found")
|
||||
|
||||
|
||||
def get_tool_file(*, session: Session, file: File):
    """Fetch the ToolFile row matching the file's ``related_id`` and ``tenant_id``.

    Raises:
        ValueError: If ``file.related_id`` is None or no matching row exists.
    """
    tool_file_id = file.related_id
    if tool_file_id is None:
        raise ValueError("Missing file related_id")

    query = select(ToolFile).filter(
        ToolFile.id == tool_file_id,
        ToolFile.tenant_id == file.tenant_id,
    )
    found = session.scalar(query)
    if found:
        return found
    raise ValueError(f"tool file {tool_file_id} not found")
|
||||
@ -0,0 +1,48 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str) -> str:
    """Return the file-preview URL for *upload_file_id* with signing query parameters.

    The signature is an HMAC-SHA256 over
    ``file-preview|<upload_file_id>|<timestamp>|<nonce>`` keyed with
    ``SECRET_KEY``, URL-safe base64 encoded. The timestamp is the current Unix
    time in seconds and the nonce is 16 random bytes in hex.
    """
    issued_at = str(int(time.time()))
    nonce = os.urandom(16).hex()

    payload = f"file-preview|{upload_file_id}|{issued_at}|{nonce}".encode()
    digest = hmac.new(dify_config.SECRET_KEY.encode(), payload, hashlib.sha256).digest()
    signature = base64.urlsafe_b64encode(digest).decode()

    base_url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
    return f"{base_url}?timestamp={issued_at}&nonce={nonce}&sign={signature}"
|
||||
|
||||
|
||||
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """Check an ``image-preview`` signature and that it has not expired.

    Recomputes the HMAC-SHA256 over
    ``image-preview|<upload_file_id>|<timestamp>|<nonce>`` keyed with
    ``SECRET_KEY`` and compares it to *sign*.

    Returns:
        bool: True only if the signature matches and the timestamp is no older
        than ``FILES_ACCESS_TIMEOUT`` seconds.
    """
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign)

    # A non-string signature can never match; reject it instead of failing on .encode().
    if not isinstance(sign, str):
        return False

    # Compare in constant time so response timing does not leak how much of the
    # signature is correct. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(sign.encode(), recalculated_encoded_sign):
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """Check a ``file-preview`` signature and that it has not expired.

    Recomputes the HMAC-SHA256 over
    ``file-preview|<upload_file_id>|<timestamp>|<nonce>`` keyed with
    ``SECRET_KEY`` and compares it to *sign*.

    Returns:
        bool: True only if the signature matches and the timestamp is no older
        than ``FILES_ACCESS_TIMEOUT`` seconds.
    """
    data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign)

    # A non-string signature can never match; reject it instead of failing on .encode().
    if not isinstance(sign, str):
        return False

    # Compare in constant time so response timing does not leak how much of the
    # signature is correct. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(sign.encode(), recalculated_encoded_sign):
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -1,243 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, MessageFile, UploadFile
|
||||
from services.file_service import IMAGE_EXTENSIONS
|
||||
|
||||
|
||||
class MessageFileParser:
    """Validate file arguments and convert them, or stored message files, into ``FileVar`` objects."""

    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(
        self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
    ) -> list[FileVar]:
        """
        validate and transform files arg

        Each entry must be a dict with a known "type" and "transfer_method",
        plus the field that method requires ("url", "upload_file_id" or
        "tool_file_id"). Image files are then checked against the image config;
        when no image config is set, image files are silently dropped.

        :param files: raw file dicts
        :param file_extra_config: upload configuration to validate against
        :param user: the uploader; local uploads must have been created by this user
        :return: the validated file objects
        :raises ValueError: on any malformed entry or violated limit
        """
        for file in files:
            if not isinstance(file, dict):
                raise ValueError("Invalid file format, must be dict")
            if not file.get("type"):
                raise ValueError("Missing file type")
            FileType.value_of(file.get("type"))
            if not file.get("transfer_method"):
                raise ValueError("Missing file transfer method")
            FileTransferMethod.value_of(file.get("transfer_method"))
            if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
                if not file.get("url"):
                    raise ValueError("Missing file url")
                if not file.get("url").startswith("http"):
                    raise ValueError("Invalid file url")
            if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
                raise ValueError("Missing file upload_file_id")
            # The key is "transfer_method"; the previous spelling "transform_method"
            # never matched, so this check was dead.
            if file.get("transfer_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
                raise ValueError("Missing file tool_file_id")

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_extra_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                # parse and validate files
                image_config = file_extra_config.image_config

                # check if image file feature is enabled
                if not image_config:
                    continue

                # Validate number of files
                if len(files) > image_config["number_limits"]:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method
                    if file_obj.transfer_method.value not in image_config["transfer_methods"]:
                        raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f"Invalid file type: {file_obj.type}")

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check that the remote url answers
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # get upload file from upload_file_id
                        upload_file = (
                            db.session.query(UploadFile)
                            .filter(
                                UploadFile.id == file_obj.related_id,
                                UploadFile.tenant_id == self.tenant_id,
                                UploadFile.created_by == user.id,
                                UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                                UploadFile.extension.in_(IMAGE_EXTENSIONS),
                            )
                            .first()
                        )

                        # check upload file is belong to tenant and user
                        if not upload_file:
                            raise ValueError("Invalid upload file")

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
        """
        transform message files

        :param files: stored message files
        :param file_extra_config: configuration attached to every resulting object
        :return: flat list of file objects (assistant-owned files are skipped)
        """
        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_extra_config)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(
        self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
    ) -> dict[FileType, list[FileVar]]:
        """
        transform files to file objs

        :param files: file dicts or MessageFile rows
        :param file_extra_config: configuration attached to every resulting object
        :return: file objects grouped by file type
        """
        type_file_objs: dict[FileType, list[FileVar]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            # Files produced by the assistant are not converted.
            if isinstance(file, MessageFile):
                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
                    continue

            file_obj = self._to_file_obj(file, file_extra_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
        """
        transform file to file obj

        :param file: a file dict or a MessageFile row
        :param file_extra_config: configuration attached to the resulting object
        :return: the FileVar built from *file*
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
            if transfer_method != FileTransferMethod.TOOL_FILE:
                return FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.value_of(file.get("type")),
                    transfer_method=transfer_method,
                    url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
                    related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                    extra_config=file_extra_config,
                )
            return FileVar(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get("type")),
                transfer_method=transfer_method,
                url=None,
                related_id=file.get("tool_file_id"),
                extra_config=file_extra_config,
            )
        else:
            return FileVar(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                related_id=file.upload_file_id or None,
                extra_config=file_extra_config,
            )

    def _check_image_remote_url(self, url):
        """Probe *url* and report whether it answers with status 200 or 304.

        URLs that look like S3 presigned URLs are tried with GET first; every
        URL that has not succeeded is then tried with HEAD.

        :return: ``(True, "")`` on success, otherwise ``(False, <error message>)``
        """
        # NOTE(review): the requests below set no timeout, so an unresponsive
        # host can block the caller - consider adding one.
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
                " Chrome/91.0.4472.124 Safari/537.36"
            }

            def is_s3_presigned_url(url):
                try:
                    parsed_url = urlparse(url)
                    if "amazonaws.com" not in parsed_url.netloc:
                        return False
                    query_params = parse_qs(parsed_url.query)

                    def check_presign_v2(query_params):
                        required_params = ["Signature", "Expires"]
                        for param in required_params:
                            if param not in query_params:
                                return False
                        if not query_params["Expires"][0].isdigit():
                            return False
                        signature = query_params["Signature"][0]
                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                            return False

                        return True

                    def check_presign_v4(query_params):
                        required_params = ["X-Amz-Signature", "X-Amz-Expires"]
                        for param in required_params:
                            if param not in query_params:
                                return False
                        if not query_params["X-Amz-Expires"][0].isdigit():
                            return False
                        signature = query_params["X-Amz-Signature"][0]
                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                            return False

                        return True

                    return check_presign_v4(query_params) or check_presign_v2(query_params)
                except Exception:
                    # Anything that cannot be parsed is treated as "not presigned".
                    return False

            if is_s3_presigned_url(url):
                response = requests.get(url, headers=headers, allow_redirects=True)
                if response.status_code in {200, 304}:
                    return True, ""

            response = requests.head(url, headers=headers, allow_redirects=True)
            if response.status_code in {200, 304}:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"
|
||||
@ -0,0 +1,140 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
from . import helpers
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
class ImageConfig(BaseModel):
    """
    Image upload settings.

    NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
    """

    # Defaults to 0.
    number_limits: int = 0
    # Transfer methods accepted for images; empty by default.
    transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    # Detail level for image prompts; None when unspecified.
    detail: Optional[ImagePromptMessageContent.DETAIL] = None
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
    """Upload constraints for files: allowed types, extensions, upload methods and a count limit."""

    # Image-specific settings; None when not provided.
    image_config: ImageConfig | None = None
    # Every list below is empty by default.
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_extensions: Sequence[str] = Field(default_factory=list)
    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    # Defaults to 0.
    number_limits: int = 0
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||
|
||||
id: Optional[str] = None # message file id
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
remote_url: Optional[str] = None # remote url
|
||||
related_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
_extra_config: FileExtraConfig | None = None
|
||||
|
||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||
data = self.model_dump(mode="json")
|
||||
return {
|
||||
**data,
|
||||
"url": self.generate_url(),
|
||||
}
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
url = self.generate_url()
|
||||
if self.type == FileType.IMAGE:
|
||||
text = f''
|
||||
else:
|
||||
text = f"[{self.filename or url}]({url})"
|
||||
|
||||
return text
|
||||
|
||||
def generate_url(self) -> Optional[str]:
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
else:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
def validate_after(self):
    """Validate the file against its transfer method and optional extra config.

    First checks that the field required by the transfer method is present
    (a ``remote_url`` starting with ``"http"`` for remote files, a
    ``related_id`` for local and tool files). When ``_extra_config`` is set,
    the file type, extension and transfer method are then checked against the
    configured allow-lists; an empty allow-list disables its check.

    Returns:
        The validated instance itself.

    Raises:
        ValueError: If any of the checks fails.
    """
    method = self.transfer_method
    if method == FileTransferMethod.REMOTE_URL:
        if not self.remote_url:
            raise ValueError("Missing file url")
        if not (isinstance(self.remote_url, str) and self.remote_url.startswith("http")):
            raise ValueError("Invalid file url")
    elif method == FileTransferMethod.LOCAL_FILE or method == FileTransferMethod.TOOL_FILE:
        # Both storage-backed methods locate the file through related_id.
        if not self.related_id:
            raise ValueError("Missing file related_id")

    # Validate the extra config.
    config = self._extra_config
    if not config:
        return self

    type_rejected = (
        config.allowed_file_types
        and self.type not in config.allowed_file_types
        and self.type != FileType.CUSTOM
    )
    if type_rejected:
        raise ValueError(f"Invalid file type: {self.type}")

    if config.allowed_extensions and self.extension not in config.allowed_extensions:
        raise ValueError(f"Invalid file extension: {self.extension}")

    if config.allowed_upload_methods and method not in config.allowed_upload_methods:
        raise ValueError(f"Invalid transfer method: {self.transfer_method}")

    if self.type == FileType.IMAGE:
        # NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
        image_config = config.image_config
        if not image_config:
            return self
        # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
        if image_config.transfer_methods and method not in image_config.transfer_methods:
            raise ValueError(f"Invalid transfer method: {self.transfer_method}")

    return self
|
||||
@ -1,79 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
# Image file extensions recognised by the parser: the lowercase spellings
# followed by their uppercase counterparts, in the same order.
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS += [extension.upper() for extension in IMAGE_EXTENSIONS]
|
||||
|
||||
|
||||
class UploadFileParser:
    """Expose uploaded image files as signed preview URLs or base64 data URIs."""

    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        """Return a representation of an uploaded image that can be sent to a model.

        :param upload_file: object exposing ``id``, ``key``, ``extension`` and ``mime_type``
        :param force_url: return a signed URL even when the configured format is not ``"url"``
        :return: a signed preview URL, a ``data:`` URI with the base64 encoded
            content, or ``None`` when there is no file, the extension is not an
            image extension, or the stored file cannot be found
        """
        if not upload_file:
            return None

        if upload_file.extension not in IMAGE_EXTENSIONS:
            return None

        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
            return cls.get_signed_temp_image_url(upload_file.id)

        # Otherwise inline the image content as a base64 data URI.
        try:
            data = storage.load(upload_file.key)
        except FileNotFoundError:
            logging.error(f"File not found: {upload_file.key}")
            return None

        encoded_string = base64.b64encode(data).decode("utf-8")
        return f"data:{upload_file.mime_type};base64,{encoded_string}"

    @classmethod
    def _sign_image_preview(cls, upload_file_id, timestamp: str, nonce: str) -> str:
        """Return the URL-safe base64 HMAC-SHA256 signature of an image-preview request."""
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(sign).decode()

    @classmethod
    def get_signed_temp_image_url(cls, upload_file_id) -> str:
        """
        Build a signed image-preview URL for an uploaded file.

        :param upload_file_id: id of the uploaded file
        :return: preview URL carrying ``timestamp``, ``nonce`` and ``sign`` query parameters
        """
        base_url = dify_config.FILES_URL
        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        encoded_sign = cls._sign_image_preview(upload_file_id, timestamp, nonce)

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        Verify the signature of an image-preview request.

        :param upload_file_id: file id
        :param timestamp: timestamp the URL was signed at, in seconds
        :param nonce: nonce
        :param sign: signature taken from the request
        :return: ``True`` when the signature matches and the timestamp is no
            older than ``FILES_ACCESS_TIMEOUT`` seconds, ``False`` otherwise
        """
        # A missing or non-string signature can never match.
        if not isinstance(sign, str):
            return False

        expected_sign = cls._sign_image_preview(upload_file_id, timestamp, nonce)

        # Constant-time comparison; bytes are used because compare_digest rejects non-ASCII str.
        if not hmac.compare_digest(sign.encode(), expected_sign.encode()):
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -0,0 +1,38 @@
|
||||
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from .message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from .model_entities import ModelPropertyKey
|
||||
|
||||
# Public API of the package: every name imported above, listed once and sorted.
__all__ = [
    "AssistantPromptMessage",
    "AudioPromptMessageContent",
    "ImagePromptMessageContent",
    "LLMResult",
    "LLMResultChunk",
    "LLMResultChunkDelta",
    "LLMUsage",
    "ModelPropertyKey",
    "PromptMessage",
    "PromptMessageContent",
    "PromptMessageContentType",
    "PromptMessageRole",
    "PromptMessageTool",
    "SystemPromptMessage",
    "TextPromptMessageContent",
    "ToolPromptMessage",
    "UserPromptMessage",
]
|
||||
@ -0,0 +1,44 @@
|
||||
model: gpt-4o-audio-preview
|
||||
label:
|
||||
zh_Hans: gpt-4o-audio-preview
|
||||
en_US: gpt-4o-audio-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '5.00'
|
||||
output: '15.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,24 @@
|
||||
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="100" height="100" rx="20" fill="#4A90E2" />
|
||||
<path
|
||||
d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
|
||||
fill="white" />
|
||||
<path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
|
||||
fill="white" />
|
||||
<path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
|
||||
<path
|
||||
d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
|
||||
fill="white" />
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
@ -0,0 +1,33 @@
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class PodcastGeneratorProvider(BuiltinToolProviderController):
    """Tool provider that validates the TTS credentials used for podcast generation."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Check that a supported TTS service and a working API key were supplied.

        :param credentials: mapping read for the ``tts_service`` and ``api_key`` entries
        :raises ToolProviderCredentialValidationError: if either entry is
            missing, the service is not ``"openai"``, or the key is rejected
        """
        tts_service = credentials.get("tts_service")
        api_key = credentials.get("api_key")

        if not tts_service:
            raise ToolProviderCredentialValidationError("TTS service is not specified")

        if not api_key:
            raise ToolProviderCredentialValidationError("API key is missing")

        if tts_service == "openai":
            self._validate_openai_credentials(api_key)
        else:
            raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")

    def _validate_openai_credentials(self, api_key: str) -> None:
        """Validate an OpenAI API key by listing the available models.

        :param api_key: the OpenAI API key to check
        :raises ToolProviderCredentialValidationError: if the key is rejected
            or the request fails for any other reason; the original exception
            is kept as the cause
        """
        client = openai.OpenAI(api_key=api_key)
        try:
            # We're using a simple API call to validate the credentials
            client.models.list()
        except openai.AuthenticationError as e:
            raise ToolProviderCredentialValidationError("Invalid OpenAI API key") from e
        except Exception as e:
            raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}") from e
|
||||
@ -0,0 +1,34 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: podcast_generator
|
||||
label:
|
||||
en_US: Podcast Generator
|
||||
zh_Hans: 播客生成器
|
||||
description:
|
||||
en_US: Generate podcast audio using Text-to-Speech services
|
||||
zh_Hans: 使用文字转语音服务生成播客音频
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
tts_service:
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: TTS Service
|
||||
zh_Hans: TTS 服务
|
||||
placeholder:
|
||||
en_US: Select a TTS service
|
||||
zh_Hans: 选择一个 TTS 服务
|
||||
options:
|
||||
- label:
|
||||
en_US: OpenAI TTS
|
||||
zh_Hans: OpenAI TTS
|
||||
value: openai
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API 密钥
|
||||
placeholder:
|
||||
en_US: Enter your TTS service API key
|
||||
zh_Hans: 输入您的 TTS 服务 API 密钥
|
||||
@ -0,0 +1,100 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
import random
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import openai
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PodcastAudioGeneratorTool(BuiltinTool):
    """Turn a two-host podcast script into a single WAV file using OpenAI text-to-speech."""

    @staticmethod
    def _generate_silence(duration: float) -> AudioSegment:
        """Return a silent audio segment lasting ``duration`` seconds."""
        # pydub uses milliseconds
        return AudioSegment.silent(duration=int(duration * 1000))

    @staticmethod
    def _generate_audio_segment(
        client: openai.OpenAI,
        line: str,
        voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
        index: int,
    ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
        """Synthesize one script line and a random pause to follow it.

        Never raises: any failure is reported through the return value so the
        caller can collect results from worker threads uniformly.

        :param client: OpenAI client used for the speech request
        :param line: script line to speak; surrounding whitespace is stripped
        :param voice: voice passed to the speech request
        :param index: position of the line in the script, returned unchanged
        :return: ``(index, audio, silence)`` on success, or
            ``(index, error_message, None)`` on failure
        """
        try:
            response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
            audio = AudioSegment.from_wav(io.BytesIO(response.content))
            # Random pause between 0.1 and 1.5 seconds.
            silence_duration = random.uniform(0.1, 1.5)
            silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
            return index, audio, silence
        except Exception as e:
            return index, f"Error generating audio: {str(e)}", None

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Generate the podcast audio for a script.

        Non-blank script lines are spoken alternately by the two host voices
        (even indexes use ``host1_voice``, odd indexes ``host2_voice``),
        synthesized by up to five worker threads and joined in script order
        with a pause after every line except the last.

        :param user_id: id of the invoking user (not used here)
        :param tool_parameters: mapping read for ``script``, ``host1_voice`` and ``host2_voice``
        :return: a text message plus a WAV blob message on success, or a single
            text message describing the first segment failure seen
        :raises ToolParameterValidationError: if either host voice is missing
        :raises ToolProviderCredentialValidationError: if the runtime,
            credentials or API key are missing
        """
        # Extract parameters; an explicit None script is treated like an empty one.
        script = tool_parameters.get("script") or ""
        host1_voice = tool_parameters.get("host1_voice")
        host2_voice = tool_parameters.get("host2_voice")

        # Split the script into lines
        script_lines = [line for line in script.split("\n") if line.strip()]

        # Ensure voices are provided
        if not host1_voice or not host2_voice:
            raise ToolParameterValidationError("Host voices are required")

        # Get OpenAI API key from credentials
        if not self.runtime or not self.runtime.credentials:
            raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
        api_key = self.runtime.credentials.get("api_key")
        if not api_key:
            raise ToolProviderCredentialValidationError("OpenAI API key is missing")

        # Initialize OpenAI client
        client = openai.OpenAI(api_key=api_key)

        # Create a thread pool
        max_workers = 5
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    self._generate_audio_segment, client, line, host1_voice if i % 2 == 0 else host2_voice, i
                )
                for i, line in enumerate(script_lines)
            ]

            # Collect results, slotting each one back at its script position.
            audio_segments: list[Any] = [None] * len(script_lines)
            for future in concurrent.futures.as_completed(futures):
                index, audio, silence = future.result()
                if isinstance(audio, str):  # Error occurred
                    # Drop requests that have not started yet; leaving the
                    # ``with`` block would otherwise run every queued call.
                    for pending in futures:
                        pending.cancel()
                    return self.create_text_message(audio)
                audio_segments[index] = (audio, silence)

        # Combine audio segments in the correct order
        combined_audio = AudioSegment.empty()
        last_index = len(audio_segments) - 1
        for i, (audio, silence) in enumerate(audio_segments):
            if audio:
                combined_audio += audio
                # No trailing pause after the final line.
                if i < last_index and silence:
                    combined_audio += silence

        # Export the combined audio to a WAV file in memory
        buffer = io.BytesIO()
        combined_audio.export(buffer, format="wav")
        wav_bytes = buffer.getvalue()

        # Create a blob message with the combined audio
        return [
            self.create_text_message("Audio generated successfully"),
            self.create_blob_message(
                blob=wav_bytes,
                meta={"mime_type": "audio/x-wav"},
                save_as=self.VariableKey.AUDIO,
            ),
        ]
|
||||
@ -0,0 +1,95 @@
|
||||
identity:
|
||||
name: podcast_audio_generator
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Podcast Audio Generator
|
||||
zh_Hans: 播客音频生成器
|
||||
description:
|
||||
human:
|
||||
en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
|
||||
zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
|
||||
llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
|
||||
parameters:
|
||||
- name: script
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Podcast Script
|
||||
zh_Hans: 播客脚本
|
||||
human_description:
|
||||
en_US: A string containing alternating lines for two hosts, separated by newline characters.
|
||||
zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
|
||||
llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
|
||||
form: llm
|
||||
- name: host1_voice
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Host 1 Voice
|
||||
zh_Hans: 主持人1 音色
|
||||
human_description:
|
||||
en_US: The voice for the first host.
|
||||
zh_Hans: 第一位主持人的音色。
|
||||
llm_description: The voice identifier for the first host's voice.
|
||||
options:
|
||||
- label:
|
||||
en_US: Alloy
|
||||
zh_Hans: Alloy
|
||||
value: alloy
|
||||
- label:
|
||||
en_US: Echo
|
||||
zh_Hans: Echo
|
||||
value: echo
|
||||
- label:
|
||||
en_US: Fable
|
||||
zh_Hans: Fable
|
||||
value: fable
|
||||
- label:
|
||||
en_US: Onyx
|
||||
zh_Hans: Onyx
|
||||
value: onyx
|
||||
- label:
|
||||
en_US: Nova
|
||||
zh_Hans: Nova
|
||||
value: nova
|
||||
- label:
|
||||
en_US: Shimmer
|
||||
zh_Hans: Shimmer
|
||||
value: shimmer
|
||||
form: form
|
||||
- name: host2_voice
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Host 2 Voice
|
||||
zh_Hans: 主持人2 音色
|
||||
human_description:
|
||||
en_US: The voice for the second host.
|
||||
zh_Hans: 第二位主持人的音色。
|
||||
llm_description: The voice identifier for the second host's voice.
|
||||
options:
|
||||
- label:
|
||||
en_US: Alloy
|
||||
zh_Hans: Alloy
|
||||
value: alloy
|
||||
- label:
|
||||
en_US: Echo
|
||||
zh_Hans: Echo
|
||||
value: echo
|
||||
- label:
|
||||
en_US: Fable
|
||||
zh_Hans: Fable
|
||||
value: fable
|
||||
- label:
|
||||
en_US: Onyx
|
||||
zh_Hans: Onyx
|
||||
value: onyx
|
||||
- label:
|
||||
en_US: Nova
|
||||
zh_Hans: Nova
|
||||
value: nova
|
||||
- label:
|
||||
en_US: Shimmer
|
||||
zh_Hans: Shimmer
|
||||
value: shimmer
|
||||
form: form
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue